Unverified Commit 57e50f8d authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
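Almost every hunk below is mechanical reformatting from the upgraded formatter rather than a behavioural change. A minimal sketch of the conventions the new tool appears to enforce, inferred from this diff (the exact formatter settings are not part of the commit, and every name in the sketch is a placeholder):

#include <map>
#include <string>

// Brace initializers lose their inner padding: { "KEY", value } becomes {"KEY", value}.
static const std::map<std::string, int> kExampleTable = {
    {"PRIVATE", 0},
    {"MIO", 1},
};

// "template <typename T>" becomes "template<typename T>", and one-line accessors are
// expanded into multi-line bodies.
template<typename T>
struct ExampleBox {
    T get() {
        return value;
    }
    T value{};
};

// Calls and signatures that exceed the column limit wrap with one argument per line.
static int exampleWrappedSignature(int firstLongParameterName,
                                   int secondLongParameterName,
                                   int thirdLongParameterName) {
    return firstLongParameterName + secondLongParameterName + thirdLongParameterName;
}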
......@@ -3,14 +3,13 @@
#include <nlohmann/json.hpp>
#include <mio/mmap.hpp>
using json = nlohmann::json;
using spdlog::fmt_lib::format;
class SafeTensors::MMapImpl {
public:
virtual ~MMapImpl() {}
virtual size_t size() = 0;
virtual size_t size() = 0;
virtual const char *data() = 0;
};
......@@ -55,7 +54,7 @@ private:
std::unique_ptr<Buffer> buffer;
};
#ifdef __linux__
#ifdef __linux__
#include <unistd.h>
#include <fcntl.h>
......@@ -97,7 +96,7 @@ private:
void *ptr;
};
#else
#else
class SafeTensors::MMapImplPrivate : public SafeTensors::MMapImpl {
public:
......@@ -117,33 +116,34 @@ public:
SafeTensors::SafeTensors(const std::string &filename) {
this->hostRegistered = false;
this->memoryPinned = false;
this->memoryPinned = false;
auto methodPrivate = [&]() {
this->mapped = std::make_unique<MMapImplPrivate>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
checkCUDA(
cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
this->hostRegistered = true;
this->memoryPinned = true;
this->memoryPinned = true;
};
auto methodMio = [&]() {
this->mapped = std::make_unique<MMapImplMio>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly));
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()),
this->mapped->size(),
cudaHostRegisterPortable | cudaHostRegisterReadOnly));
this->hostRegistered = true;
this->memoryPinned = true;
this->memoryPinned = true;
};
auto methodRead = [&]() {
this->mapped = std::make_unique<MMapImplRead>(filename, true);
this->mapped = std::make_unique<MMapImplRead>(filename, true);
this->memoryPinned = true;
};
auto methodReadNopin = [&]() {
this->mapped = std::make_unique<MMapImplRead>(filename, false);
};
auto methodReadNopin = [&]() { this->mapped = std::make_unique<MMapImplRead>(filename, false); };
const std::map<std::string, std::function<void()>> methods = {
{ "PRIVATE", methodPrivate },
{ "MIO", methodMio },
{ "READ", methodRead },
{ "READNOPIN", methodReadNopin },
{"PRIVATE", methodPrivate},
{"MIO", methodMio},
{"READ", methodRead},
{"READNOPIN", methodReadNopin},
};
auto tryMethod = [&](std::string name) {
......@@ -168,7 +168,6 @@ SafeTensors::SafeTensors(const std::string &filename) {
#else
tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#endif
}
if (!this->mapped) {
......@@ -192,19 +191,20 @@ SafeTensors::~SafeTensors() {
void SafeTensors::parseHeader() {
static const std::unordered_map<std::string, Tensor::ScalarType> mapDType = {
{ "BF16", Tensor::BF16 },
{ "F16", Tensor::FP16 },
{ "F32", Tensor::FP32 },
{ "I8", Tensor::INT8 },
{ "I32", Tensor::INT32 },
{ "I64", Tensor::INT64 },
{ "F8_E4M3", Tensor::FP8_E4M3 },
{ "F8_E5M2", Tensor::FP8_E5M2 },
{"BF16", Tensor::BF16},
{"F16", Tensor::FP16},
{"F32", Tensor::FP32},
{"I8", Tensor::INT8},
{"I32", Tensor::INT32},
{"I64", Tensor::INT64},
{"F8_E4M3", Tensor::FP8_E4M3},
{"F8_E5M2", Tensor::FP8_E5M2},
};
auto check = [](bool cond, std::source_location location = std::source_location::current()) {
if (!cond) {
throw std::runtime_error(format("Safetensors check failed at {}:{}", location.file_name(), location.line()));
throw std::runtime_error(
format("Safetensors check failed at {}:{}", location.file_name(), location.line()));
}
};
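// check(cond) reports the caller's file and line: the defaulted std::source_location::current()
// argument is evaluated at each call site, not inside the lambda.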
......@@ -222,8 +222,9 @@ void SafeTensors::parseHeader() {
continue;
}
auto dtype = mapDType.at(info["dtype"].get<std::string>());;
auto shape = info["shape"].get<std::vector<int>>();
auto dtype = mapDType.at(info["dtype"].get<std::string>());
;
auto shape = info["shape"].get<std::vector<int>>();
auto data_offsets = info["data_offsets"].get<std::vector<uint64_t>>();
check(data_offsets.size() == 2);
......@@ -235,8 +236,8 @@ void SafeTensors::parseHeader() {
}
TensorInfo tinfo;
tinfo.type = dtype;
tinfo.shape = TensorShape(shape);
tinfo.type = dtype;
tinfo.shape = TensorShape(shape);
tinfo.length = data_offsets[1] - data_offsets[0];
tinfo.offset = 8 + sizeHeader + data_offsets[0];
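// safetensors layout: an 8-byte little-endian header length, then the JSON header (sizeHeader bytes),
// then raw tensor data; data_offsets are relative to the start of that data section, hence 8 + sizeHeader.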
......@@ -258,15 +259,15 @@ Tensor SafeTensors::getTensor(const std::string &key) {
std::shared_ptr<BufferMMap> buffer = info.buffer.lock();
if (!buffer) {
buffer = std::make_shared<BufferMMap>(const_cast<char *>(this->mapped->data() + info.offset), info.length, shared_from_this());
buffer = std::make_shared<BufferMMap>(
const_cast<char *>(this->mapped->data() + info.offset), info.length, shared_from_this());
info.buffer = buffer;
}
Tensor result;
result.shape = info.shape;
result.shape = info.shape;
result.scalarType = info.type;
result.buffer = buffer;
result.buffer = buffer;
return result;
}
......@@ -6,15 +6,15 @@
class BufferMMap : public Buffer {
public:
BufferMMap(void *ptr, size_t size, std::shared_ptr<void> parent) : parent(parent) {
this->size = size;
this->size = size;
this->device.type = Device::CPU;
this->ptr = ptr;
this->ptr = ptr;
// auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
// if (ret == cudaSuccess) {
// this->registered = true;
// } else {
// log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size, cudaGetErrorString(cudaGetLastError())));
// this->registered = false;
// log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
// cudaGetErrorString(cudaGetLastError()))); this->registered = false;
// }
}
virtual ~BufferMMap() {
......@@ -22,6 +22,7 @@ public:
// checkCUDA(cudaHostUnregister(ptr));
// }
}
public:
std::shared_ptr<void> parent;
// bool registered;
......@@ -32,7 +33,7 @@ public:
SafeTensors(const std::string &filename);
~SafeTensors();
virtual bool contains(const std::string &key) const override {
virtual bool contains(const std::string &key) const override {
return tensors.contains(key);
}
virtual Tensor getTensor(const std::string &key) override;
......@@ -57,4 +58,4 @@ private:
std::unique_ptr<MMapImpl> mapped;
bool hostRegistered, memoryPinned;
};
\ No newline at end of file
};
......@@ -3,13 +3,10 @@
#include "common.h"
struct Device {
enum Type {
INVALID_DEVICE_TYPE = 0,
CPU, CUDA
};
enum Type { INVALID_DEVICE_TYPE = 0, CPU, CUDA };
Type type = INVALID_DEVICE_TYPE;
int idx = 0;
int idx = 0;
static constexpr Device cpu(int idx = 0) {
return Device{CPU, idx};
......@@ -23,21 +20,29 @@ struct Device {
class Buffer : public std::enable_shared_from_this<Buffer> {
public:
virtual ~Buffer() {}
void *getPtr() { return ptr; }
void *getPtr() {
return ptr;
}
template<typename T>
T *getPtr() { return reinterpret_cast<T *>(ptr); }
T *getPtr() {
return reinterpret_cast<T *>(ptr);
}
size_t getSize() { return size; }
Device getDevice() { return device; }
size_t getSize() {
return size;
}
Device getDevice() {
return device;
}
virtual bool isAsyncBuffer() {
virtual bool isAsyncBuffer() {
return false;
}
protected:
template <typename Derived>
template<typename Derived>
std::shared_ptr<Derived> shared_from_base() {
return std::static_pointer_cast<Derived>(shared_from_this());
}
......@@ -55,9 +60,9 @@ protected:
class BufferMalloc : public Buffer {
public:
BufferMalloc(size_t size) {
this->size = size;
this->size = size;
this->device.type = Device::CPU;
this->ptr = malloc(size);
this->ptr = malloc(size);
}
virtual ~BufferMalloc() {
free(this->ptr);
......@@ -67,7 +72,7 @@ public:
class BufferHost : public Buffer {
public:
BufferHost(size_t size) {
this->size = size;
this->size = size;
this->device.type = Device::CPU;
checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
}
......@@ -79,7 +84,7 @@ public:
class BufferCUDA : public Buffer {
public:
BufferCUDA(size_t size) {
this->size = size;
this->size = size;
this->device.type = Device::CUDA;
// checkCUDA(cudaGetDevice(&this->device.idx));
this->device.idx = CUDADeviceContext::getDevice();
......@@ -96,7 +101,7 @@ public:
}
checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
}
virtual bool isAsyncBuffer() override {
virtual bool isAsyncBuffer() override {
return true;
}
};
......@@ -104,7 +109,7 @@ public:
class BufferCUDASync : public Buffer {
public:
BufferCUDASync(size_t size) {
this->size = size;
this->size = size;
this->device.type = Device::CUDA;
checkCUDA(cudaGetDevice(&this->device.idx));
checkCUDA(cudaMalloc(&this->ptr, size));
......@@ -118,8 +123,8 @@ class BufferView : public Buffer {
public:
BufferView(std::shared_ptr<Buffer> reference, size_t offset, size_t size) : reference(reference) {
assert(offset + size <= reference->getSize());
this->ptr = (void *)((std::uint8_t *)reference->getPtr() + offset);
this->size = size;
this->ptr = (void *)((std::uint8_t *)reference->getPtr() + offset);
this->size = size;
this->device = reference->getDevice();
}
......@@ -213,23 +218,31 @@ struct TensorShape {
}
};
class Tensor {
public:
enum ScalarType {
INVALID_SCALAR_TYPE,
INT8, INT16, INT32, INT64,
FP16, FP32, BF16,
FP8_E4M3, FP8_E5M2,
INT8,
INT16,
INT32,
INT64,
FP16,
FP32,
BF16,
FP8_E4M3,
FP8_E5M2,
};
struct TensorOptions {
Device device_;
ScalarType dtype_;
Device device() const { return device_; }
ScalarType dtype() const { return dtype_; }
Device device() const {
return device_;
}
ScalarType dtype() const {
return dtype_;
}
TensorOptions device(Device dev) const {
TensorOptions result(*this);
......@@ -244,56 +257,95 @@ public:
};
static const std::map<ScalarType, size_t> scalarSize;
public:
TensorShape shape;
ScalarType scalarType;
std::shared_ptr<Buffer> buffer;
public:
bool valid() const { return shape.dataExtent.size() > 0; }
int size(int dim) const { return shape[dim]; }
bool is_contiguous() const { return shape.is_contiguous(); }
std::vector<int> sizes() const { return shape.dataExtent; }
bool valid() const {
return shape.dataExtent.size() > 0;
}
int size(int dim) const {
return shape[dim];
}
bool is_contiguous() const {
return shape.is_contiguous();
}
std::vector<int> sizes() const {
return shape.dataExtent;
}
bool is_cuda() const { return device().type == Device::CUDA; }
bool is_cuda() const {
return device().type == Device::CUDA;
}
TensorOptions options() const { return TensorOptions{device(), dtype()}; }
int get_device() const { return device().idx; }
TensorOptions options() const {
return TensorOptions{device(), dtype()};
}
int get_device() const {
return device().idx;
}
template<typename T>
T *data_ptr() { return reinterpret_cast<T*>(data_ptr()); }
T *data_ptr() {
return reinterpret_cast<T *>(data_ptr());
}
template<typename T>
const T *data_ptr() const { return reinterpret_cast<const T*>(data_ptr()); }
const void *data_ptr() const { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
void *data_ptr() { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
const T *data_ptr() const {
return reinterpret_cast<const T *>(data_ptr());
}
Device device() const { return buffer->getDevice(); }
const void *data_ptr() const {
return buffer->getPtr<char>() + shape.offset * scalar_size();
}
void *data_ptr() {
return buffer->getPtr<char>() + shape.offset * scalar_size();
}
ScalarType scalar_type() const { return scalarType; }
ScalarType dtype() const { return scalar_type(); }
Device device() const {
return buffer->getDevice();
}
ScalarType scalar_type() const {
return scalarType;
}
ScalarType dtype() const {
return scalar_type();
}
size_t stride(int dim) const { return shape.stride(dim); }
size_t stride(int dim) const {
return shape.stride(dim);
}
size_t numel() const { return shape.size(); }
size_t ndims() const { return shape.ndims(); }
size_t numel() const {
return shape.size();
}
size_t ndims() const {
return shape.ndims();
}
size_t dim() const { return ndims(); }
size_t dim() const {
return ndims();
}
size_t scalar_size() const { return scalarSize.at(scalarType); }
size_t scalar_size() const {
return scalarSize.at(scalarType);
}
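// operator[](idx) below returns a view of row idx: the result shares the underlying storage
// through a BufferView (no data is copied).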
Tensor operator[](int idx) const {
assert(ndims() > 1);
Tensor result;
result.shape = std::vector<int>(this->shape.dataExtent.begin() + 1, this->shape.dataExtent.end());
size_t size = stride(0) * scalar_size();
result.buffer = std::make_shared<BufferView>(this->buffer, idx * size, size);
result.shape = std::vector<int>(this->shape.dataExtent.begin() + 1, this->shape.dataExtent.end());
size_t size = stride(0) * scalar_size();
result.buffer = std::make_shared<BufferView>(this->buffer, idx * size, size);
result.scalarType = this->scalarType;
return result;
}
template<typename T>
const T & at(const std::vector<int> &idx) const {
const T &at(const std::vector<int> &idx) const {
assert(ndims() == idx.size());
int64_t offset = 0;
for (size_t i = 0; i < ndims(); i++) {
......@@ -304,17 +356,17 @@ public:
}
template<typename T>
T & at(const std::vector<int> &idx) {
T &at(const std::vector<int> &idx) {
return const_cast<T &>(const_cast<const Tensor *>(this)->at<T>(idx));
}
Tensor slice(int dim, int from, int to) const {
assert(from <= to);
Tensor result;
result.buffer = this->buffer;
result.buffer = this->buffer;
result.scalarType = this->scalarType;
result.shape = TensorShape(this->shape.dataExtent);
result.shape = TensorShape(this->shape.dataExtent);
result.shape[dim] = to - from;
result.shape.dataStride.resize(result.shape.ndims());
for (int i = 0; i < result.shape.ndims(); i++) {
......@@ -326,7 +378,7 @@ public:
}
Tensor transpose(int dim1, int dim2) const {
Tensor result;
result.buffer = this->buffer;
result.buffer = this->buffer;
result.scalarType = this->scalarType;
result.shape = TensorShape(this->shape.dataExtent);
......@@ -346,9 +398,9 @@ public:
assert(shape.size() == this->shape.size());
assert(this->is_contiguous());
Tensor result;
result.buffer = this->buffer;
result.scalarType = this->scalarType;
result.shape = shape;
result.buffer = this->buffer;
result.scalarType = this->scalarType;
result.shape = shape;
result.shape.offset = this->shape.offset;
return result;
}
......@@ -363,7 +415,8 @@ public:
Tensor &zero_() {
assert(this->is_contiguous());
checkCUDA(cudaMemsetAsync(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
checkCUDA(cudaMemsetAsync(
data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
return *this;
}
Tensor &copy_(Tensor other) {
......@@ -380,23 +433,17 @@ public:
}
if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
memcpy(
data_ptr<char>(),
other.data_ptr<char>(),
shape.size() * scalar_size()
);
memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
return *this;
}
lockBuffer(this->buffer, getCurrentCUDAStream());
lockBuffer(other.buffer, getCurrentCUDAStream());
checkCUDA(cudaMemcpyAsync(
data_ptr<char>(),
other.data_ptr<char>(),
shape.size() * scalar_size(),
getCopyKind(this->device(), other.device()),
getCurrentCUDAStream()
));
checkCUDA(cudaMemcpyAsync(data_ptr<char>(),
other.data_ptr<char>(),
shape.size() * scalar_size(),
getCopyKind(this->device(), other.device()),
getCurrentCUDAStream()));
return *this;
}
......@@ -425,14 +472,15 @@ public:
assert(false);
}
result.scalarType = scalarType;
result.shape = shape;
result.shape = shape;
if (fill) {
if (device.type == Device::CPU) {
memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
} else if (device.type == Device::CUDA) {
CUDADeviceContext ctx(device.idx);
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
checkCUDA(
cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
}
}
......@@ -450,11 +498,12 @@ public:
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
return result;
}
static Tensor allocate_view(TensorShape shape, ScalarType scalarType, std::shared_ptr<Buffer> buffer, size_t offset = 0) {
static Tensor
allocate_view(TensorShape shape, ScalarType scalarType, std::shared_ptr<Buffer> buffer, size_t offset = 0) {
Tensor result;
result.buffer = std::make_shared<BufferView>(buffer, offset, shape.size() * scalarSize.at(scalarType));
result.buffer = std::make_shared<BufferView>(buffer, offset, shape.size() * scalarSize.at(scalarType));
result.scalarType = scalarType;
result.shape = shape;
result.shape = shape;
return result;
}
......@@ -468,13 +517,16 @@ public:
// lockBuffer(this->buffer, getCurrentCUDAStream());
// lockBuffer(result.buffer, getCurrentCUDAStream());
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
// if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyHostToDevice, getCurrentCUDAStream()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
// getCurrentCUDAStream())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyHostToDevice, getCurrentCUDAStream()));
// } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
// } else {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDefault, getCurrentCUDAStream()));
// }
return result;
}
......@@ -516,9 +568,10 @@ private:
// }
static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
public:
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU completes
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU
// completes
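// Usage sketch (mirrors Tensor::copy_ above; dst/src are placeholder tensors):
//   lockBuffer(dst.buffer, getCurrentCUDAStream());
//   lockBuffer(src.buffer, getCurrentCUDAStream());
//   checkCUDA(cudaMemcpyAsync(dst.data_ptr<char>(), src.data_ptr<char>(),
//                             dst.numel() * dst.scalar_size(),
//                             getCopyKind(dst.device(), src.device()), getCurrentCUDAStream()));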
static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
if (!buffer->isAsyncBuffer()) {
lockedBuffers[stream].insert(buffer);
......@@ -558,5 +611,5 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
struct TensorsProvider {
virtual ~TensorsProvider() {}
virtual bool contains(const std::string &key) const = 0;
virtual Tensor getTensor(const std::string &key) = 0;
};
\ No newline at end of file
virtual Tensor getTensor(const std::string &key) = 0;
};
......@@ -22,13 +22,15 @@ Tensor GELU::forward(Tensor x) {
// return out;
// }
// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor
// quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant_fuse_sum(quantized_mlp_act_buffer, out, quantized_sum_buffer, quantized_scale_buffer);
// return out;
// }
// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer,
// Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant(quantized_mlp_act_buffer, out, quantized_scale_buffer, {});
// return out;
......
......@@ -20,7 +20,8 @@ public:
// class SiluAndMulQuant {
// public:
// static Tensor forward(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer, bool act_sum) {
// static Tensor forward(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor
// quantized_sum_buffer, bool act_sum) {
// if (act_sum) {
// return forward_with_act_sum(x, quantized_mlp_act_buffer, quantized_scale_buffer, quantized_sum_buffer);
// } else {
......@@ -28,6 +29,7 @@ public:
// }
// }
// private:
// static Tensor forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
// static Tensor forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
// };
\ No newline at end of file
// static Tensor forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer,
// Tensor quantized_sum_buffer); static Tensor forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor
// quantized_scale_buffer, Tensor quantized_sum_buffer);
// };
......@@ -25,7 +25,7 @@
class CUDAError : public std::runtime_error {
public:
CUDAError(cudaError_t errorCode, std::source_location location)
CUDAError(cudaError_t errorCode, std::source_location location)
: std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}
public:
......@@ -34,12 +34,13 @@ public:
private:
static std::string format(cudaError_t errorCode, std::source_location location) {
return spdlog::fmt_lib::format("CUDA error: {} (at {}:{})",
cudaGetErrorString(errorCode), location.file_name(), location.line());
return spdlog::fmt_lib::format(
"CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line());
}
};
inline cudaError_t checkCUDA(cudaError_t retValue, const std::source_location location = std::source_location::current()) {
inline cudaError_t checkCUDA(cudaError_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != cudaSuccess) {
(void)cudaGetLastError();
throw CUDAError(retValue, location);
......@@ -47,10 +48,11 @@ inline cudaError_t checkCUDA(cudaError_t retValue, const std::source_location lo
return retValue;
}
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue, const std::source_location location = std::source_location::current()) {
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != CUBLAS_STATUS_SUCCESS) {
throw std::runtime_error(spdlog::fmt_lib::format("CUBLAS error: {} (at {}:{})",
cublasGetStatusString(retValue), location.file_name(), location.line()));
throw std::runtime_error(spdlog::fmt_lib::format(
"CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
}
return retValue;
}
......@@ -71,8 +73,8 @@ struct CUDAStreamContext {
stackCUDAStreams.push(stream);
}
CUDAStreamContext(const CUDAStreamContext &) = delete;
CUDAStreamContext(CUDAStreamContext &&) = delete;
CUDAStreamContext(CUDAStreamContext &&) = delete;
~CUDAStreamContext() {
assert(stackCUDAStreams.top() == stream);
stackCUDAStreams.pop();
......@@ -86,7 +88,7 @@ struct CUDAStreamWrapper {
checkCUDA(cudaStreamCreate(&stream));
}
CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
~CUDAStreamWrapper() {
checkCUDA(cudaStreamDestroy(stream));
......@@ -100,14 +102,13 @@ struct CUDAEventWrapper {
checkCUDA(cudaEventCreateWithFlags(&event, flags));
}
CUDAEventWrapper(const CUDAEventWrapper &) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
~CUDAEventWrapper() {
checkCUDA(cudaEventDestroy(event));
}
};
/**
* 1. hold one when entered from external code (set `device` to -1 to avoid device change)
* 2. hold one when switching device
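*
* Usage sketch (illustrative; the guard records the previous device so it can be restored on destruction):
*   { CUDADeviceContext ctx(-1); ... }   // case 1: entered from external code, no device switch
*   { CUDADeviceContext ctx(3);  ... }   // case 2: run the enclosed scope on CUDA device 3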
......@@ -121,7 +122,7 @@ public:
// previous context is reset on => external code may be executed, reset
currentDeviceCache = -1;
}
ctxs.push(this);
lastDevice = getDevice();
if (device >= 0) {
......@@ -134,7 +135,7 @@ public:
}
}
CUDADeviceContext(const CUDADeviceContext &) = delete;
CUDADeviceContext(CUDADeviceContext &&) = delete;
CUDADeviceContext(CUDADeviceContext &&) = delete;
~CUDADeviceContext() {
if (disableCache) {
......@@ -148,7 +149,8 @@ public:
if (cacheDisabled()) {
// ctxs.empty() => we are about to return to external code, reset cache
// otherwise => we are a nested context in a previous context with reset on, we might continue to execute external code, reset
// otherwise => we are a nested context in a previous context with reset on, we might continue to execute
// external code, reset
currentDeviceCache = -1;
}
}
......@@ -156,7 +158,6 @@ public:
const bool disableCache;
int lastDevice;
public:
static int getDevice() {
int idx = -1;
......@@ -168,6 +169,7 @@ public:
currentDeviceCache = cacheDisabled() ? -1 : idx;
return idx;
}
private:
static void setDevice(int idx) {
// TODO: deal with stream when switching device
......@@ -207,11 +209,11 @@ constexpr T ceilDiv(T a, T b) {
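// log2Up(value) computes ceil(log2(value)) for value >= 1; e.g. log2Up(1) == 0, log2Up(2) == 1, log2Up(5) == 3.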
template<typename T>
constexpr int log2Up(T value) {
if (value <= 0)
return 0;
if (value == 1)
return 0;
return log2Up((value + 1) / 2) + 1;
if (value <= 0)
return 0;
if (value == 1)
return 0;
return log2Up((value + 1) / 2) + 1;
}
struct CUBLASWrapper {
......@@ -220,7 +222,7 @@ struct CUBLASWrapper {
CUBLASWrapper() {
checkCUBLAS(cublasCreate(&handle));
}
CUBLASWrapper(CUBLASWrapper &&) = delete;
CUBLASWrapper(CUBLASWrapper &&) = delete;
CUBLASWrapper(const CUBLASWrapper &&) = delete;
~CUBLASWrapper() {
if (handle) {
......@@ -236,6 +238,6 @@ inline std::shared_ptr<CUBLASWrapper> getCUBLAS() {
return result;
}
result = std::make_shared<CUBLASWrapper>();
inst = result;
inst = result;
return result;
}
\ No newline at end of file
}
......@@ -9,8 +9,8 @@ public:
ctxs.insert(this);
}
DebugContext(const DebugContext &) = delete;
DebugContext(DebugContext &&) = delete;
DebugContext(DebugContext &&) = delete;
~DebugContext() {
ctxs.erase(this);
}
......@@ -19,4 +19,3 @@ public:
static inline thread_local std::set<DebugContext *> ctxs;
};
......@@ -22,20 +22,20 @@ Tensor from_torch(at::Tensor input) {
}
static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
{ at::ScalarType::Char, Tensor::INT8 },
{ at::ScalarType::Byte, Tensor::INT8 },
{ at::ScalarType::Int, Tensor::INT32 },
{ at::ScalarType::Long, Tensor::INT64 },
{ at::ScalarType::Float, Tensor::FP32 },
{ at::ScalarType::Half, Tensor::FP16 },
{ at::ScalarType::BFloat16, Tensor::BF16 },
{ at::ScalarType::Short, Tensor::INT16 },
{ at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3 },
{ at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2 },
{at::ScalarType::Char, Tensor::INT8},
{at::ScalarType::Byte, Tensor::INT8},
{at::ScalarType::Int, Tensor::INT32},
{at::ScalarType::Long, Tensor::INT64},
{at::ScalarType::Float, Tensor::FP32},
{at::ScalarType::Half, Tensor::FP16},
{at::ScalarType::BFloat16, Tensor::BF16},
{at::ScalarType::Short, Tensor::INT16},
{at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3},
{at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2},
};
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
......@@ -51,15 +51,15 @@ at::Tensor to_torch(Tensor input) {
}
static const std::map<Tensor::ScalarType, at::ScalarType> mapType = {
{ Tensor::INT8, at::ScalarType::Byte },
{ Tensor::INT32, at::ScalarType::Int },
{ Tensor::INT64, at::ScalarType::Long },
{ Tensor::FP32, at::ScalarType::Float },
{ Tensor::FP16, at::ScalarType::Half },
{ Tensor::BF16, at::ScalarType::BFloat16 },
{ Tensor::INT16, at::ScalarType::Short },
{ Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn },
{ Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2 },
{Tensor::INT8, at::ScalarType::Byte},
{Tensor::INT32, at::ScalarType::Int},
{Tensor::INT64, at::ScalarType::Long},
{Tensor::FP32, at::ScalarType::Float},
{Tensor::FP16, at::ScalarType::Half},
{Tensor::BF16, at::ScalarType::BFloat16},
{Tensor::INT16, at::ScalarType::Short},
{Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn},
{Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2},
};
c10::TensorOptions opts(mapType.at(input.scalar_type()));
......@@ -82,4 +82,4 @@ TorchOpContext::TorchOpContext() {
TorchOpContext::~TorchOpContext() {
assert(stackCUDAStreams.top() == at::cuda::getCurrentCUDAStream().stream());
stackCUDAStreams.pop();
}
\ No newline at end of file
}
......@@ -8,15 +8,16 @@
class BufferTorchTensor : public Buffer {
public:
BufferTorchTensor(at::Tensor tensor) : tensor(std::move(tensor)) {
this->size = this->tensor.numel() * this->tensor.itemsize();
this->ptr = this->tensor.data_ptr();
this->size = this->tensor.numel() * this->tensor.itemsize();
this->ptr = this->tensor.data_ptr();
this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
this->device.idx = this->tensor.get_device();
this->device.idx = this->tensor.get_device();
}
virtual bool isAsyncBuffer() override {
// TODO: figure out how torch manages memory
return this->device.type == Device::CUDA;
}
private:
at::Tensor tensor;
};
......@@ -25,7 +26,7 @@ class TorchOpContext {
public:
TorchOpContext();
TorchOpContext(const TorchOpContext &) = delete;
TorchOpContext(TorchOpContext &&) = delete;
TorchOpContext(TorchOpContext &&) = delete;
~TorchOpContext();
};
......@@ -48,4 +49,4 @@ public:
private:
std::map<std::string, at::Tensor> storage;
};
\ No newline at end of file
};
......@@ -3,91 +3,84 @@
#include "dispatch_utils.h"
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
});
void silu_and_mul(
Tensor& out, // [..., d]
Tensor& input) // [..., 2 * d]
void silu_and_mul(Tensor &out, // [..., d]
Tensor &input) // [..., 2 * d]
{
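// One thread block per token; threads within a block stride over the d output elements
// (the kernel's "idx += blockDim.x" loop), so the block size is capped at 1024.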
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
// dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
// vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
// out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
// });
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
// dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
// vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
// out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
// });
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
vllm::silu_and_mul_kernel<scalar_t>
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
}
void invoke_dequant_silu_and_mul_quant(
Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate, const float scale_up, const float scale_out) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate,
scale_up, scale_out);
void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate,
const float scale_up,
const float scale_out) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate, scale_up, scale_out);
}
void invoke_dequant_silu_and_mul_quant(
Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate, const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [..., d]
void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate,
const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [..., d]
) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float*, true><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(),
d, scale_gate, scale_up, scale_out.data_ptr<float>(), tmp.data_ptr<float>());
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float *, true><<<grid, block, 0, stream>>>(out.data_ptr<int8_t>(),
input.data_ptr<int32_t>(),
d,
scale_gate,
scale_up,
scale_out.data_ptr<float>(),
tmp.data_ptr<float>());
}
void silu(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
void silu(Tensor &out, // [..., d]
Tensor &input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::silu);
LAUNCH_ACTIVATION_KERNEL(vllm::silu);
}
void gelu_new(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
void gelu_new(Tensor &out, // [..., d]
Tensor &input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
void gelu_fast(Tensor &out, // [..., d]
Tensor &input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
\ No newline at end of file
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
......@@ -3,9 +3,8 @@
#include "common.h"
#include "Tensor.h"
void silu(
Tensor& out, // [..., d]
Tensor& input);
void silu(Tensor &out, // [..., d]
Tensor &input);
void silu_and_mul(Tensor &out, // [..., d]
Tensor &input); // [..., 2 * d]
......@@ -25,5 +24,5 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
const float scale_gate,
const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [num_tokens, d]
);
\ No newline at end of file
Tensor &tmp // [num_tokens, d]
);
......@@ -3,116 +3,104 @@
namespace vllm {
template <typename T> __device__ __forceinline__ T silu(const T &x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
template<typename T>
__device__ __forceinline__ T silu(const T &x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2 * d]
const int d) {
__global__ void silu_and_mul_kernel(scalar_t *__restrict__ out, // [..., d]
const scalar_t *__restrict__ input, // [..., 2 * d]
const int d) {
const int token_idx = blockIdx.x;
const int64_t token_idx_d = token_idx * int64_t(d);
const int64_t token_idx_2d = token_idx_d * 2;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx_2d + idx]);
const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
out[token_idx_d + idx] = silu(x) * y;
}
const int token_idx = blockIdx.x;
const int64_t token_idx_d = token_idx * int64_t(d);
const int64_t token_idx_2d = token_idx_d * 2;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx_2d + idx]);
const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
out[token_idx_d + idx] = silu(x) * y;
}
}
// dequant int32 input, apply silu and mul, then per token quant to int8
template <typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(
int8_t *__restrict__ out, // [..., d]
const int32_t *__restrict__ input, // [..., 2 * d]
const int d, const float scale_gate, const float scale_up,
scale_type scale_out, // [num_tokens]
float *__restrict__ tmp = nullptr // [num_tokens, d]
template<typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out, // [..., d]
const int32_t *__restrict__ input, // [..., 2 * d]
const int d,
const float scale_gate,
const float scale_up,
scale_type scale_out, // [num_tokens]
float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
const int token_idx = blockIdx.x;
if constexpr (use_per_token_quant) {
float amax_val = 0.0f;
const float zero = 0.0f;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x =
(float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y =
(float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
float t = silu(x) * y;
tmp[token_idx * d + idx] = t;
t = t > zero ? t : -t;
if (t > amax_val)
amax_val = t;
}
__shared__ float s_amax;
const float block_amax_val = blockReduceMax(amax_val);
if (threadIdx.x == 0) {
s_amax = block_amax_val;
scale_out[token_idx] = block_amax_val / 127.0f;
const int token_idx = blockIdx.x;
if constexpr (use_per_token_quant) {
float amax_val = 0.0f;
const float zero = 0.0f;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
float t = silu(x) * y;
tmp[token_idx * d + idx] = t;
t = t > zero ? t : -t;
if (t > amax_val)
amax_val = t;
}
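// Per-token path: the block-wide max |activation| (blockReduceMax below) defines the scale,
// scale_out[token_idx] = amax / 127, and each value is then quantized as float_to_int8_rn(value * 127 / amax).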
__shared__ float s_amax;
const float block_amax_val = blockReduceMax(amax_val);
if (threadIdx.x == 0) {
s_amax = block_amax_val;
scale_out[token_idx] = block_amax_val / 127.0f;
}
__syncthreads();
float tmp_scale = 127.0f / s_amax;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
out[token_idx * d + idx] = float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
}
} else {
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
}
}
__syncthreads();
float tmp_scale = 127.0f / s_amax;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
out[token_idx * d + idx] =
float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
}
} else {
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x =
(float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y =
(float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
}
}
}
} // namespace vllm
namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t &)>
__global__ void activation_kernel(scalar_t *__restrict__ out, // [..., d]
const scalar_t *__restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
} // namespace vllm
namespace vllm {
template <typename T> __device__ __forceinline__ T gelu_new_kernel(const T &x) {
const float x3 = (float)(x * x * x);
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T &x) {
const float x3 = (float)(x * x * x);
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
}
template <typename T>
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
const float f = (float)x;
const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
const float f = (float)x;
const T t = (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
}
} // namespace vllm
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
......@@ -13,16 +14,15 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
#include <cuda_fp16.h>
#include <cstdint>
__forceinline__ __device__
void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
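// immLut encodes the LOP3 boolean function (a & b) | c; 0x64006400 is the packed fp16 pair (1024.0, 1024.0),
// so OR-ing a 4-bit value v into its mantissa yields 1024 + v (1024 + 16*v for the TOP_MASK nibbles),
// which the arithmetic below maps back to v.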
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
......@@ -75,25 +75,26 @@ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
__forceinline__ __device__
void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
// dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
// *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->x));
// *reinterpret_cast<__nv_bfloat162 *>(&result->y) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y));
// *reinterpret_cast<__nv_bfloat162 *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->w));
// *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *>(&result->x)); *reinterpret_cast<__nv_bfloat162 *>(&result->y) =
// cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__nv_bfloat162
// *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *>(&result->w));
// return;
// uint4 result;
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
......@@ -127,4 +128,4 @@ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_67
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
}
\ No newline at end of file
}
......@@ -2,14 +2,13 @@
#include <cuda_bf16.h>
#include "semaphore.h"
#include "gemm_awq.h"
//#include "../../../nunchaku/csrc/quantization/dequantize.cuh"
// #include "../../../nunchaku/csrc/quantization/dequantize.cuh"
#include "dequantize.cuh"
#include <stdio.h>
#include "../dispatch_utils.h"
//#include "../../../nunchaku/csrc/utils.cuh"
// #include "../../../nunchaku/csrc/utils.cuh"
#include "../utils.cuh"
#include <cuda_pipeline_primitives.h>
#define kInterleave 4
......@@ -29,1141 +28,1342 @@
#define L2_CACHEHINT(size)
#endif
#define KERNEL_LAUNCH_CODE \
int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(f16_t); \
if (kSmemByteSize >= 99 * 1024) \
{ \
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \
return _out_feats; \
} \
int j_factors1 = num_out_channels / CTA_N / 1; \
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n)
{
if (N >= 8 && n >= 6)
return 3;
else if (N >= 4 && n >= 3)
return 2;
else if (N >= 2 && n >= 2)
return 1;
else
return 0;
#define KERNEL_LAUNCH_CODE \
int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = \
(G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
constexpr int kSmemByteSize = \
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * \
sizeof(f16_t); \
if (kSmemByteSize >= 99 * 1024) { \
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \
return _out_feats; \
} \
int j_factors1 = num_out_channels / CTA_N / 1; \
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
template<int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
if (N >= 8 && n >= 6)
return 3;
else if (N >= 4 && n >= 3)
return 2;
else if (N >= 2 && n >= 2)
return 1;
else
return 0;
}
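// Illustrative values: get_log_tile<8>(7) == 3, get_log_tile<8>(4) == 2, get_log_tile<8>(2) == 1, get_log_tile<8>(1) == 0.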
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}
template <int SLICES, int NUM_WARPS_MN>
__device__ void sync_slice(int slice_id)
{
if constexpr (SLICES == 1)
{
__syncthreads();
}
else
{
constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
}
template<int SLICES, int NUM_WARPS_MN>
__device__ void sync_slice(int slice_id) {
if constexpr (SLICES == 1) {
__syncthreads();
} else {
constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
}
}
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr)
{
uint32_t smem_int_ptr;
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
return smem_int_ptr;
return smem_int_ptr;
}
template <typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr)
{
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
}
template <typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr)
{
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
}
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
{
const int cp_size = 16;
asm volatile("{"
" .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
"}" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
const int cp_size = 16;
asm volatile("{"
" .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
"}" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
}
template <typename f16_t>
template<typename f16_t>
__device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16_t *B_shared_warp);
template <>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
template<>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]),
"r"(((unsigned *)A_shared_warp)[1]),
"r"(((unsigned *)A_shared_warp)[2]),
"r"(((unsigned *)A_shared_warp)[3]),
"r"(((unsigned *)B_shared_warp)[0]),
"r"(((unsigned *)B_shared_warp)[1]),
"f"(((float *)C_warp)[0]),
"f"(((float *)C_warp)[1]),
"f"(((float *)C_warp)[2]),
"f"(((float *)C_warp)[3]));
}
template <>
__device__ __inline__ void mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp)
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
template<>
__device__ __inline__ void
mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]),
"r"(((unsigned *)A_shared_warp)[1]),
"r"(((unsigned *)A_shared_warp)[2]),
"r"(((unsigned *)A_shared_warp)[3]),
"r"(((unsigned *)B_shared_warp)[0]),
"r"(((unsigned *)B_shared_warp)[1]),
"f"(((float *)C_warp)[0]),
"f"(((float *)C_warp)[1]),
"f"(((float *)C_warp)[2]),
"f"(((float *)C_warp)[3]));
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(f16_t *src, f16_t *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(f16_t *src,
f16_t *dst,
int global_nrows,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int cta_offset_k,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
}
else
{
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr =
(uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K +
cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols +
// threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row)
// * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) *
// PACK_SIZE);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
} else {
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
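// Note: global_to_share_one_stage_A stages one CTA_M x CTA_K slice of A into
// padded shared memory (row stride CTA_K + SMEM_PAD_A), XOR-swizzling the
// column index to avoid bank conflicts. With STAGES > 1 it issues cp.async so
// the copy overlaps compute; with a single stage it falls back to plain uint4
// stores. The ld_row + cta_offset_m < global_nrows check guards the M boundary.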
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(f16_t *src, f16_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(f16_t *src,
f16_t *dst,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int cta_offset_k,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
}
else
{
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols +
ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
} else {
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
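// Note: the B path mirrors the A path but loads the interleaved, 4-bit-packed
// weight tile (CTA_N / kInterleave rows of CTA_K packed elements) with a
// lighter swizzle keyed on ld_row % 2. Unlike the A copy, local_mask carries
// no row-bound check, so the packed B layout is presumably padded to whole
// CTA_N tiles by the caller.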
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales(f16_t *src, f16_t *dst, f16_t *src_z, f16_t *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = LD_AMOUNT / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;
void *dst_ptr = (void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
}
else
{
if (local_mask)
{
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales(f16_t *src,
f16_t *dst,
f16_t *src_z,
f16_t *dst_z,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int cta_offset_k,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = LD_AMOUNT / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;
void *dst_ptr =
(void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr =
(uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols +
(threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z =
(void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z =
(uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols +
(threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
} else {
if (local_mask) {
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
}
}
}
}
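// Note: scales and zero-points are staged together. g_idx selects the
// quantization group (size G) covering the current K slice, and one CTA_N-wide
// row per group is copied for each tensor, again via cp.async when STAGES > 1.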
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void share_to_reg_one_stage_A(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) {
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
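// Note: share_to_reg_one_stage_A uses ldmatrix.m8n8.x4.b16 to move four 8x8
// 16-bit tiles per call from the swizzled shared-memory layout into the
// per-thread A fragments, one INTRIN_K-wide slice (k_0_1) at a time.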
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B(f16_t *src, f16_t *src_scales, f16_t *src_zeros, f16_t *dst, f16_t *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix)
{
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B(f16_t *src,
f16_t *src_scales,
f16_t *src_zeros,
f16_t *dst,
f16_t *dst_fp16,
int warp_offset_m,
int warp_offset_n,
int warp_offset_k,
int k_0_1) {
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix) {
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
void *addr_ptr =
(void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol +
k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
}
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
f16_t scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f16_t zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
f16_t scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) +
threadIdx.x / 4];
f16_t zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) +
threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8),
reinterpret_cast<uint4 *>(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
loaded[i] = __hfma2(loaded[i], scale2, zero2);
for (int i = 0; i < 4; i++) {
loaded[i] = __hfma2(loaded[i], scale2, zero2);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
}
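// Note: the B register path first ldmatrix-loads the packed 4-bit weights into
// `dst` (only when the `ldmatrix` template flag is set), then
// dequantize_s4_to_fp16x2 expands them to f16x2/bf16x2 pairs and __hfma2
// applies the per-group scale and zero-point, writing the dequantized fragment
// into dst_fp16 for the MMA loop.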
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G, int SPLITK>
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K)
{
template<typename f16_t,
int CTA_M,
int CTA_N,
int CTA_K,
int WARP_M,
int WARP_N,
int WARP_K,
int STAGES,
int G,
int SPLITK>
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
f16_t *__restrict__ B,
f16_t *__restrict__ scales,
f16_t *__restrict__ zeros,
f16_t *__restrict__ C,
int *__restrict__ semaphores,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
trap_unsupported_arch();
return;
#endif
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
float *C_shared = reinterpret_cast<float *>(mem_shared);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int cta_offset_k = blockIdx_z * (K / SPLITK);
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
float *C_shared = reinterpret_cast<float *>(mem_shared);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int cta_offset_k = blockIdx_z * (K / SPLITK);
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A,
A_shared +
k_0_0_ld * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
0,
true);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B,
B_shared +
k_0_0_ld * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
0,
true);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales,
scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
zeros,
zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
0,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared,
scales_shared,
zeros_shared,
B_shared_warp_tmp_[0],
B_shared_warp_[0],
warp_offset_m,
warp_offset_n,
warp_offset_k,
0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
#pragma unroll
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage =
scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
zeros_shared_this_compute_stage =
zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared_this_compute_stage,
A_shared_warp_[(iter_k + 1) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0) {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
}
} else {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
}
}
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1) {
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2) {
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
__syncthreads();
}
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales,
scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
zeros,
zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1) {
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1) {
#pragma unroll
for (int z = 0; z < SLICES; ++z)
{
if (slice_id == z)
{
for (int z = 0; z < SLICES; ++z) {
if (slice_id == z) {
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
if (z > 0)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
if (z > 0) {
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] +=
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n +
ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
(local_id % 2) + (threadIdx.x % 4) * 2] =
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
}
}
__syncthreads();
}
}
__syncthreads();
}
if (slice_id == 0)
{
if (slice_id == 0) {
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
};
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] =
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
(local_id % 2) + (threadIdx.x % 4) * 2];
};
}
}
}
}
}
}
if (slice_id == 0)
{
Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
if (slice_id == 0) {
Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
if constexpr (SPLITK > 1)
{
semaphore.fetch();
}
if constexpr (SPLITK > 1) {
semaphore.fetch();
}
if (blockIdx_z != 0)
{
semaphore.wait(blockIdx_z);
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
f162_t *existing_psum_ptr = reinterpret_cast<f162_t *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2);
*existing_psum_ptr = __hadd2(
*existing_psum_ptr,
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(
C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
if (blockIdx_z != 0) {
semaphore.wait(blockIdx_z);
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M) {
f162_t *existing_psum_ptr = reinterpret_cast<f162_t *>(
C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + (local_id / 4) * 8 +
(local_id % 2) + (threadIdx.x % 4) * 2);
*existing_psum_ptr =
__hadd2(*existing_psum_ptr,
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(
C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
}
};
}
}
};
}
}
}
else
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
*reinterpret_cast<f162_t *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
} else {
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M) {
*reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n +
ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2) +
(threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
}
};
}
}
}
if constexpr (SPLITK > 1)
{
if constexpr (SPLITK > 1) {
int lock = 0;
if (SPLITK == blockIdx_z + 1)
{
int lock = 0;
if (SPLITK == blockIdx_z + 1) {
lock = 0;
}
else
{
lock = blockIdx_z + 1;
}
semaphore.release(lock);
lock = 0;
} else {
lock = blockIdx_z + 1;
}
semaphore.release(lock);
}
}
}
}
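// Note: gemm_w4a16_T1 is the split-K variant. blockIdx.x encodes both the
// swizzled (m, n) tile and the K split (blockIdx_z); partial sums are reduced
// across SLICES through shared memory, and when SPLITK > 1 a per-tile
// semaphore serializes the accumulation of each split's result into C.
//
// A minimal, hypothetical launch sketch (illustration only, not this diff's or
// the library's dispatch code) showing how grid, block, dynamic shared memory,
// and the semaphore buffer relate to the template parameters above.
// launch_gemm_w4a16_T1_sketch is an invented name; SMEM_PAD_A/B, kInterleave,
// and WARP_SIZE are the constants already used in this header.
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K,
         int STAGES, int G, int SPLITK>
void launch_gemm_w4a16_T1_sketch(f16_t *A, f16_t *B, f16_t *scales, f16_t *zeros, f16_t *C,
                                 int *semaphores, // >= blocks_mn ints, zero-initialized
                                 int M, int N, int K, cudaStream_t stream) {
    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K);
    const int blocks_mn = ((M + CTA_M - 1) / CTA_M) * ((N + CTA_N - 1) / CTA_N);
    // Shared memory holds the A/B stages plus scales/zeros (counted in halves);
    // it is also reused as a float C tile for the slice reduction, so take the larger.
    constexpr size_t smem_ab = (size_t(CTA_M) * (CTA_K + SMEM_PAD_A) +
                                size_t(CTA_N / kInterleave) * (CTA_K + SMEM_PAD_B)) *
                               STAGES * sizeof(half);
    constexpr size_t smem_scales = 2ull * CTA_N * STAGES * sizeof(half); // upper bound, scales + zeros
    constexpr size_t smem_c      = size_t(CTA_M) * CTA_N * sizeof(float);
    constexpr size_t smem_bytes  = (smem_ab + smem_scales > smem_c) ? smem_ab + smem_scales : smem_c;
    // For larger tiles this may exceed the 48 KB default and would require
    // cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...) beforehand.
    dim3 grid(blocks_mn * SPLITK), block(WARP_SIZE, NUM_WARPS);
    gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>
        <<<grid, block, smem_bytes, stream>>>(A, B, scales, zeros, C, semaphores, M, N, K);
}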
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A_T2(f16_t *src, f16_t *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A_T2(f16_t *src,
f16_t *dst,
int global_nrows,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
}
else
{
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr =
(uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE +
global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n *
// global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols +
// (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K
// + (threadIdx.x % threads_per_row) * PACK_SIZE);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
} else {
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
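// Note: the *_T2 helpers below mirror their T1 counterparts but drop the
// K-slice offset (cta_offset_k) and the per-slice scale indexing, since
// gemm_w4a16_T2 walks the full K dimension within a single block and does not
// use split-K or K-slicing.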
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B_T2(f16_t *src, f16_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B_T2(f16_t *src,
f16_t *dst,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
}
else
{
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols +
ld_row * global_ncols + ld_col * PACK_SIZE);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
} else {
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales_T2(f16_t *src, f16_t *dst, f16_t *src_z, f16_t *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = global_iter_k * CTA_K / G;
void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
}
else
{
if (local_mask)
{
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales_T2(f16_t *src,
f16_t *dst,
f16_t *src_z,
f16_t *dst_z,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = global_iter_k * CTA_K / G;
void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z =
(uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
} else {
if (local_mask) {
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
}
}
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void share_to_reg_one_stage_A_T2(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A_T2(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1) {
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src, f16_t *src_scales, f16_t *src_zeros, f16_t *dst, f16_t *dst_fp16, int warp_offset_m, int warp_offset_n, int k_0_1)
{
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix)
{
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src,
f16_t *src_scales,
f16_t *src_zeros,
f16_t *dst,
f16_t *dst_fp16,
int warp_offset_m,
int warp_offset_n,
int k_0_1) {
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix) {
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
void *addr_ptr =
(void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol +
k_0_1 * 16 + r * kSmemCol + c_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
}
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
f16_t scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f16_t zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
f16_t scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f16_t zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8),
reinterpret_cast<uint4 *>(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
loaded[i] = __hfma2(loaded[i], scale2, zero2);
for (int i = 0; i < 4; i++) {
loaded[i] = __hfma2(loaded[i], scale2, zero2);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int M, int N, int K)
{
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
f16_t *__restrict__ B,
f16_t *__restrict__ scales,
f16_t *__restrict__ zeros,
f16_t *__restrict__ C,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
trap_unsupported_arch();
return;
#endif
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
constexpr int kSmemSizeZeros = CTA_N * STAGES / 2;
constexpr int scales_load_interval = G / CTA_K;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
constexpr int kSmemSizeZeros = CTA_N * STAGES / 2;
constexpr int scales_load_interval = G / CTA_K;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
__syncthreads();
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (ld_stage / scales_load_interval) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
scales,
scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
zeros,
zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
0,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared,
scales_shared,
zeros_shared,
B_shared_warp_tmp_[0],
B_shared_warp_[0],
warp_offset_m,
warp_offset_n,
0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared_this_compute_stage,
A_shared_warp_[(iter_k + 1) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0) {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
}
} else {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
}
}
__syncthreads();
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1) {
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2) {
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
__syncthreads();
}
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales,
scales_shared + (ld_stage / scales_load_interval) * CTA_N,
zeros,
zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1) {
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
*reinterpret_cast<f162_t *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
int write_row =
cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M) {
*reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
};
}
}
}
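
Editor's note: gemm_w4a16_T2 overlaps global-to-shared copies with tensor-core math through the __pipeline_commit / __pipeline_wait_prior(STAGES - 2) calls above. Stripped of the kernel's addressing, the pattern looks roughly like the sketch below (generic tile copies, assuming tiles >= STAGES; not the kernel's real code).

#include <cuda_pipeline_primitives.h>

// Minimal sketch of a STAGES-deep software pipeline: keep STAGES - 1 async copy
// groups in flight, wait for the oldest one, compute on it, then refill its slot.
template<int STAGES>
__device__ void pipeline_sketch(const float4 *gmem, float4 *smem_tiles, int tiles) {
    // smem_tiles must point to __shared__ storage with room for STAGES float4 slots.
#pragma unroll
    for (int s = 0; s < STAGES - 1; ++s) { // prologue: issue STAGES - 1 copy groups
        __pipeline_memcpy_async(&smem_tiles[s], &gmem[s], sizeof(float4));
        __pipeline_commit();
    }
    for (int t = 0; t < tiles; ++t) {
        __pipeline_wait_prior(STAGES - 2); // the oldest outstanding copy has landed
        __syncthreads();
        // ... compute on smem_tiles[t % STAGES] here ...
        int next = t + STAGES - 1;
        if (next < tiles)
            __pipeline_memcpy_async(&smem_tiles[next % STAGES], &gmem[next], sizeof(float4));
        __pipeline_commit();               // commit every iteration to keep the count in step
        __syncthreads();
    }
}
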
Tensor awq_gemm_forward_cuda(
Tensor _in_feats,
Tensor _kernel,
Tensor _scales,
Tensor _zeros)
{
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = _kernel.size(0) * kInterleave;
int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
int num_in_channels = _in_feats.size(-1);
auto options =
Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
auto options_int =
Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
int num_out_channels = _out_feats.size(-1);
if (_in_feats.scalar_type() == Tensor::FP16)
{
using f16_t = half;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 64)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 128)
{
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 192)
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024)
{
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
}
else if (_in_feats.scalar_type() == Tensor::BF16)
{
using f16_t = __nv_bfloat16;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 64)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 128)
{
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 192)
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024)
{
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, Tensor _zeros) {
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = _kernel.size(0) * kInterleave;
int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
int num_in_channels = _in_feats.size(-1);
auto options = Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
auto options_int = Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
int num_out_channels = _out_feats.size(-1);
if (_in_feats.scalar_type() == Tensor::FP16) {
using f16_t = half;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 64) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 128) {
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 192) {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize =
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES *
sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024) {
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else if (_in_feats.scalar_type() == Tensor::BF16) {
using f16_t = __nv_bfloat16;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 64) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 128) {
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 192) {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize =
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES *
sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024) {
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else {
throw std::runtime_error("Unsupported input type");
}
}
else
{
throw std::runtime_error("Unsupported input type");
}
return _out_feats;
}
\ No newline at end of file
return _out_feats;
}
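
Editor's note: the 99 KiB guard above follows directly from the per-stage tile sizes declared inside gemm_w4a16_T2. Below is a small host-side consistency check for the default large-M configuration, assuming SMEM_PAD_A = SMEM_PAD_B = 8 halves (those constants are defined elsewhere in this file and are not visible in this hunk).

// Sketch only: recomputes kSmemByteSize for CTA_M=64, CTA_N=128, CTA_K=64, STAGES=4.
constexpr int CTA_M = 64, CTA_N = 128, CTA_K = 64, STAGES = 4, kInterleave = 4;
constexpr int SMEM_PAD_A = 8, SMEM_PAD_B = 8; // assumed padding, not taken from this hunk
constexpr int kPerStageHalves =
    CTA_M * (CTA_K + SMEM_PAD_A)                 // A tile
    + CTA_N / kInterleave * (CTA_K + SMEM_PAD_B) // interleaved, packed B tile
    + CTA_N;                                     // scales + zeros (CTA_N / 2 halves each)
constexpr int kSmemByteSize = kPerStageHalves * STAGES * 2; // 2 bytes per fp16 element
static_assert(kSmemByteSize == 56320 && kSmemByteSize < 99 * 1024,
              "same structure as the launcher's formula and within the guarded budget");
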
......@@ -3,9 +3,4 @@
#include "common.h"
#include "Tensor.h"
Tensor awq_gemm_forward_cuda(
Tensor _in_feats,
Tensor _kernel,
Tensor _scales,
Tensor _zeros);
Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, Tensor _zeros);
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Modified from NVIDIA
* [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -39,106 +40,98 @@
#define MEM_ACCESS_SIZE 128
// Reduce sum within the warp using the tree reduction algorithm.
template <typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t* psum, float (*out_smem)[Num * 4])
{
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i)
{
fpsum[i] = static_cast<float>(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i)
{
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
{
#pragma unroll
for (int i = 0; i < Num; ++i)
{
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
template<typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t *psum, float (*out_smem)[Num * 4]) {
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i) {
fpsum[i] = static_cast<float>(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i) {
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6) {
#pragma unroll
for (int i = 0; i < Num; ++i) {
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
};
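
Editor's note: the reduction above deliberately folds only across XOR distances 16, 8 and 1, because the kInterleave = 4 weight layout scatters each output channel's partial sums over lanes {0,1,8,9,16,17,24,25}. For contrast, a conventional full-warp tree reduction (a generic sketch, not code from this file) iterates over all five distances:

// Generic butterfly reduction: after the loop every lane holds the sum of all 32 lanes.
__device__ __forceinline__ float warp_sum_all(float v) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_xor_sync(0xffffffff, v, offset);
    return v;
}
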
__device__ __forceinline__ int make_divisible(int c, int divisor){
return (c + divisor - 1) / divisor;
__device__ __forceinline__ int make_divisible(int c, int divisor) {
return (c + divisor - 1) / divisor;
}
template<typename half_t>
__device__ __forceinline__
packed_as<half_t, 2>::type half2half2(half_t x);
template<typename half_t>
__device__ __forceinline__ packed_as<half_t, 2>::type half2half2(half_t x);
template<>
__device__ __forceinline__
packed_as<half, 2>::type half2half2<half>(half x) {
template<>
__device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
return __half2half2(x);
}
template<>
__device__ __forceinline__
packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
template<typename T>
__device__ __forceinline__
float2 half22float2(T val);
__device__ __forceinline__ float2 half22float2(T val);
template<>
__device__ __forceinline__
float2 half22float2<half2>(half2 val) {
__device__ __forceinline__ float2 half22float2<half2>(half2 val) {
return __half22float2(val);
}
template<>
__device__ __forceinline__
float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
template <typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const half_t* inputs, const uint32_t* weight, const half_t* scales, const half_t* zeros, half_t* outputs,
const int IC, const int OC)
{
template<typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(const half_t *inputs,
const uint32_t *weight,
const half_t *scales,
const half_t *zeros,
half_t *outputs,
const int IC,
const int OC) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
if constexpr(std::is_same_v<half_t, __nv_bfloat16>) {
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
trap_unsupported_arch();
return;
}
#endif
using half2_t = typename packed_as<half_t, 2>::type;
using accum_t = float;
using half2_t = typename packed_as<half_t, 2>::type;
using accum_t = float;
using accum2_t = typename packed_as<accum_t, 2>::type;
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
// assert(MEM_ACCESS_SIZE == 128);
// static constexpr int kShuffleSize = 32;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
static constexpr int kShuffleStrided = 4;
constexpr int Num = NPerBlock * Batch;
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
alignas(16) half_t local_inputs[kElemsPerThread];
alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
alignas(16) half_t half_weight_buffer[kElemsPerThread];
alignas(16) half_t half_weight_buffer[kElemsPerThread];
alignas(16) half_t dequantized_weight[kElemsPerThread * NPerBlock];
alignas(16) half_t local_scale[NPerBlock];
alignas(16) half_t local_scaled_zeros[NPerBlock];
......@@ -146,7 +139,7 @@ __global__ void gemv_kernel(
accum_t psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = static_cast<accum_t>(0.f);
// extern __shared__ uint8_t shmem[];
// float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
......@@ -154,80 +147,67 @@ __global__ void gemv_kernel(
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride
+ (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride +
(threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half_t* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* inputs_ptr = inputs + act_k_offset;
const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half_t *scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t *zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t *inputs_ptr = inputs + act_k_offset;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
    // Main loop: each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) {
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx) {
            // use float4 to load weights; each thread loads 32 int4 values (1 x float4, 128 bits)
*((float4*)(local_qweights)) =
*((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
*((float4 *)(local_qweights)) = *((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) {
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i),
reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
half2_t w =
*reinterpret_cast<half2_t*>(
half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile
);
w = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)
* NPerBlock + idx]
= w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)
* NPerBlock + idx]
= w.y;
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i) {
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j) {
half2_t w = *reinterpret_cast<half2_t *>(half_weight_buffer +
(i + j * kShuffleContinous) * kShuffleBasicTile);
w = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
}
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const half_t* local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) {
const half_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx) {
                // Load activations, 8 halves (128 bits) per step.
*((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8));
*((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
accum2_t prod = cuda_cast<accum2_t>(__hmul2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2), half2half2(local_inputs[y])));
*reinterpret_cast<accum2_t*>(psum + batch_idx * NPerBlock + x * 2)
= prod + *reinterpret_cast<accum2_t*>(psum + batch_idx * NPerBlock + x * 2);
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x) {
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y) {
accum2_t prod = cuda_cast<accum2_t>(
__hmul2(*reinterpret_cast<half2_t *>(dequantized_weight + y * NPerBlock + x * 2),
half2half2(local_inputs[y])));
*reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2) =
prod + *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
// *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2)
// = __hfma2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2),
// half2half2(local_inputs[y]),
......@@ -243,13 +223,11 @@ __global__ void gemv_kernel(
warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);
    // Num * kInterleave = Batch * NPerBlock * kInterleave -> one thread block writes back Num * kInterleave outputs
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) {
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j) {
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half_t>(acc);
......@@ -271,32 +249,24 @@ Returns:
out_feats: tensor of shape [B, OC];
*/
Tensor gemv_awq(
Tensor _in_feats,
Tensor _kernel,
Tensor _scaling_factors,
Tensor _zeros,
int m,
int n,
int k,
int group_size)
{
Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size) {
return dispatchFloat16(_scaling_factors.scalar_type(), [&]<typename half_t>() {
assert(isTypeMatch<half_t>(_in_feats.dtype()));
auto output_shape = _in_feats.shape.dataExtent;
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = n;
auto in_feats = reinterpret_cast<half_t*>(_in_feats.data_ptr<half_t>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half_t*>(_zeros.data_ptr<half_t>());
auto scaling_factors = reinterpret_cast<half_t*>(_scaling_factors.data_ptr<half_t>());
auto in_feats = reinterpret_cast<half_t *>(_in_feats.data_ptr<half_t>());
auto kernel = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half_t *>(_zeros.data_ptr<half_t>());
auto scaling_factors = reinterpret_cast<half_t *>(_scaling_factors.data_ptr<half_t>());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
half_t * out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
half_t *out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
static constexpr int BLOCK_SIZE = 256;
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
......@@ -312,9 +282,9 @@ Tensor gemv_awq(
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
<<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
checkCUDA(cudaGetLastError());
}
});
......
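
Editor's note on the launch geometry above: each thread block covers N_PER_BLOCK * K_INTERLEAVE = 8 output channels, so the grid size is n / 8 and n is implicitly assumed to be a multiple of 8. A tiny host-side sketch with a hypothetical channel count:

#include <cstdio>

// Sketch only: reproduces the dim3 arithmetic used when launching gemv_kernel.
int main() {
    constexpr int N_PER_BLOCK = 2, K_INTERLEAVE = 4, BLOCK_SIZE = 256;
    int n = 4096;                                // hypothetical output-channel count
    int blocks = n / N_PER_BLOCK / K_INTERLEAVE; // 512 blocks, 8 output channels each
    printf("grid = %d blocks x %d threads\n", blocks, BLOCK_SIZE);
    return 0;
}
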
......@@ -3,12 +3,5 @@
#include "common.h"
#include "Tensor.h"
Tensor gemv_awq(
Tensor _in_feats,
Tensor _kernel,
Tensor _scaling_factors,
Tensor _zeros,
int m,
int n,
int k,
int group_size);
\ No newline at end of file
Tensor
gemv_awq(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size);
......@@ -41,65 +41,54 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore
{
class Semaphore {
public:
int *lock;
bool wait_thread;
int state;
int *lock;
bool wait_thread;
int state;
public:
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_),
wait_thread(thread_id < 0 || thread_id == 0),
state(-1)
{
}
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id)
: lock(lock_), wait_thread(thread_id < 0 || thread_id == 0), state(-1) {}
/// Permit fetching the synchronization mechanism early
__device__ void fetch()
{
if (wait_thread)
{
/// Permit fetching the synchronization mechanism early
__device__ void fetch() {
if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
}
/// Gets the internal state
__device__ int get_state() const
{
return state;
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0)
{
while (__syncthreads_and(state != status))
{
fetch();
/// Gets the internal state
__device__ int get_state() const {
return state;
}
__syncthreads();
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0) {
while (__syncthreads_and(state != status)) {
fetch();
}
__syncthreads();
}
/// Updates the lock with the given result
__device__ void release(int status = 0)
{
__syncthreads();
/// Updates the lock with the given result
__device__ void release(int status = 0) {
__syncthreads();
if (wait_thread)
{
if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
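
Editor's note: the class above is a CUTLASS-style spin lock; a typical use is serializing a split-K reduction so that partial-result CTAs update the output tile in order. A hedged usage sketch (not code from this file; the accumulation step is elided):

// Sketch only: CTA `split_idx` waits for its turn, reduces its partials into the
// output, then hands the token to the next CTA. Assumes `lock` was zero-initialized.
__device__ void splitk_serial_epilogue(int *lock, int split_idx, int num_splits) {
    Semaphore semaphore(lock, threadIdx.x); // only thread 0 polls the flag
    semaphore.fetch();                      // start the acquire early, overlap with epilogue prep
    semaphore.wait(split_idx);              // block until all previous splits have released
    // ... read-modify-write this CTA's partial results into global memory here ...
    semaphore.release((split_idx + 1) % num_splits); // the last split resets the flag to 0
}
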
......
......@@ -16,4 +16,4 @@ inline void dispatchF16(Tensor::ScalarType type, F &&func) {
} else {
assert(false);
}
}
\ No newline at end of file
}
......@@ -53,17 +53,16 @@ inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
}
#pragma nv_diagnostic push
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template
// "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
template<typename T>
inline bool isTypeMatch(Tensor::ScalarType scalarType) {
return dispatch(scalarType, []<typename scalar_t>() {
return std::is_same_v<scalar_t, T>;
});
return dispatch(scalarType, []<typename scalar_t>() { return std::is_same_v<scalar_t, T>; });
}
#pragma nv_diagnostic pop
template<typename F, int ...N>
template<typename F, int... N>
inline auto dispatchVal(int val, std::integer_sequence<int, N...>, F &&func) {
auto call = [&]<int i>() {
if (val == i) {
......@@ -82,5 +81,4 @@ inline auto dispatchBool(bool val, F &&func) {
}
}
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) dispatchFloat(TYPE, [&]<typename scalar_t>() { __VA_ARGS__(); });
\ No newline at end of file
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) dispatchFloat(TYPE, [&]<typename scalar_t>() { __VA_ARGS__(); });
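
Editor's note: as a usage illustration (a hypothetical call site, not from this file; dispatchBool's body is collapsed above and is assumed to mirror dispatchVal, invoking the lambda with the flag as a compile-time bool):

// Sketch only: lifting runtime values into template parameters via the helpers above.
inline void dispatch_usage_sketch(Tensor::ScalarType runtime_dtype, bool fuse_bias) {
    dispatchBool(fuse_bias, [&]<bool kFuseBias>() {
        // kFuseBias is a compile-time constant here, so it can select a kernel
        // template specialization without branching inside the kernel.
    });
    bool is_half = isTypeMatch<half>(runtime_dtype); // does this dtype dispatch to `half`?
    (void)is_half;
}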