Unverified commit 57e50f8d, authored by Muyang Li and committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
@@ -3,14 +3,13 @@
#include <nlohmann/json.hpp>
#include <mio/mmap.hpp>

using json = nlohmann::json;

using spdlog::fmt_lib::format;

class SafeTensors::MMapImpl {
public:
    virtual ~MMapImpl() {}
    virtual size_t size() = 0;
    virtual const char *data() = 0;
};
@@ -55,7 +54,7 @@ private:
    std::unique_ptr<Buffer> buffer;
};

#ifdef __linux__

#include <unistd.h>
#include <fcntl.h>
@@ -97,7 +96,7 @@ private:
    void *ptr;
};

#else

class SafeTensors::MMapImplPrivate : public SafeTensors::MMapImpl {
public:
@@ -117,33 +116,34 @@ public:

SafeTensors::SafeTensors(const std::string &filename) {
    this->hostRegistered = false;
    this->memoryPinned = false;

    auto methodPrivate = [&]() {
        this->mapped = std::make_unique<MMapImplPrivate>(filename);
        checkCUDA(
            cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
        this->hostRegistered = true;
        this->memoryPinned = true;
    };
    auto methodMio = [&]() {
        this->mapped = std::make_unique<MMapImplMio>(filename);
        checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()),
                                   this->mapped->size(),
                                   cudaHostRegisterPortable | cudaHostRegisterReadOnly));
        this->hostRegistered = true;
        this->memoryPinned = true;
    };
    auto methodRead = [&]() {
        this->mapped = std::make_unique<MMapImplRead>(filename, true);
        this->memoryPinned = true;
    };
    auto methodReadNopin = [&]() { this->mapped = std::make_unique<MMapImplRead>(filename, false); };

    const std::map<std::string, std::function<void()>> methods = {
        {"PRIVATE", methodPrivate},
        {"MIO", methodMio},
        {"READ", methodRead},
        {"READNOPIN", methodReadNopin},
    };

    auto tryMethod = [&](std::string name) {
@@ -168,7 +168,6 @@ SafeTensors::SafeTensors(const std::string &filename) {
#else
        tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#endif
    }

    if (!this->mapped) {
@@ -192,19 +191,20 @@ SafeTensors::~SafeTensors() {

void SafeTensors::parseHeader() {
    static const std::unordered_map<std::string, Tensor::ScalarType> mapDType = {
        {"BF16", Tensor::BF16},
        {"F16", Tensor::FP16},
        {"F32", Tensor::FP32},
        {"I8", Tensor::INT8},
        {"I32", Tensor::INT32},
        {"I64", Tensor::INT64},
        {"F8_E4M3", Tensor::FP8_E4M3},
        {"F8_E5M2", Tensor::FP8_E5M2},
    };

    auto check = [](bool cond, std::source_location location = std::source_location::current()) {
        if (!cond) {
            throw std::runtime_error(
                format("Safetensors check failed at {}:{}", location.file_name(), location.line()));
        }
    };
@@ -222,8 +222,9 @@ void SafeTensors::parseHeader() {
            continue;
        }

        auto dtype = mapDType.at(info["dtype"].get<std::string>());
        ;
        auto shape = info["shape"].get<std::vector<int>>();
        auto data_offsets = info["data_offsets"].get<std::vector<uint64_t>>();

        check(data_offsets.size() == 2);
@@ -235,8 +236,8 @@ void SafeTensors::parseHeader() {
        }

        TensorInfo tinfo;
        tinfo.type = dtype;
        tinfo.shape = TensorShape(shape);
        tinfo.length = data_offsets[1] - data_offsets[0];
        tinfo.offset = 8 + sizeHeader + data_offsets[0];
@@ -258,15 +259,15 @@ Tensor SafeTensors::getTensor(const std::string &key) {
    std::shared_ptr<BufferMMap> buffer = info.buffer.lock();
    if (!buffer) {
        buffer = std::make_shared<BufferMMap>(
            const_cast<char *>(this->mapped->data() + info.offset), info.length, shared_from_this());
        info.buffer = buffer;
    }

    Tensor result;
    result.shape = info.shape;
    result.scalarType = info.type;
    result.buffer = buffer;
    return result;
}
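(Not part of the diff: the constructor above registers a set of mapping strategies and tries them in order until one succeeds. A minimal standalone sketch of that fallback pattern, with hypothetical names, is shown here for orientation only.)

// Illustrative sketch only (hypothetical names, not code from this commit):
// try each mapping strategy in order and stop at the first one that does not throw.
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

static bool tryStrategies(const std::map<std::string, std::function<void()>> &methods,
                          const std::vector<std::string> &order) {
    for (const std::string &name : order) {
        try {
            methods.at(name)(); // e.g. mmap the file, then cudaHostRegister it
            return true;        // first strategy that succeeds wins
        } catch (const std::exception &e) {
            std::cerr << name << " failed: " << e.what() << ", falling back\n";
        }
    }
    return false; // caller reports that the file could not be mapped
}

int main() {
    const std::map<std::string, std::function<void()>> methods = {
        {"MIO", [] { throw std::runtime_error("pinning not available"); }},
        {"READ", [] { /* plain read into malloc'ed memory: always works */ }},
    };
    return tryStrategies(methods, {"MIO", "READ"}) ? 0 : 1;
}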
@@ -6,15 +6,15 @@
class BufferMMap : public Buffer {
public:
    BufferMMap(void *ptr, size_t size, std::shared_ptr<void> parent) : parent(parent) {
        this->size = size;
        this->device.type = Device::CPU;
        this->ptr = ptr;
        // auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
        // if (ret == cudaSuccess) {
        //     this->registered = true;
        // } else {
        //     log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
        //     cudaGetErrorString(cudaGetLastError()))); this->registered = false;
        // }
    }
    virtual ~BufferMMap() {
@@ -22,6 +22,7 @@ public:
        // checkCUDA(cudaHostUnregister(ptr));
        // }
    }

public:
    std::shared_ptr<void> parent;
    // bool registered;
@@ -32,7 +33,7 @@ public:
    SafeTensors(const std::string &filename);
    ~SafeTensors();

    virtual bool contains(const std::string &key) const override {
        return tensors.contains(key);
    }
    virtual Tensor getTensor(const std::string &key) override;
@@ -57,4 +58,4 @@ private:
    std::unique_ptr<MMapImpl> mapped;
    bool hostRegistered, memoryPinned;
};
\ No newline at end of file
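(Not part of the diff: BufferMMap above keeps a std::shared_ptr<void> parent so the mapping cannot be released while any buffer view is still alive. A standalone sketch of that keep-the-owner-alive idiom, with hypothetical names:)

// Illustration only: a view extends its owner's lifetime through shared_ptr<void>.
#include <cassert>
#include <memory>
#include <vector>

struct View {
    const char *data;
    std::shared_ptr<void> parent; // not dereferenced, only keeps the owner alive
};

int main() {
    std::weak_ptr<std::vector<char>> probe;
    View v;
    {
        auto file = std::make_shared<std::vector<char>>(16, 'x'); // stand-in for a mapped file
        probe = file;
        v = View{file->data(), file};
    } // `file` goes out of scope here...
    assert(!probe.expired()); // ...but the storage stays alive through the view
    return 0;
}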
@@ -3,13 +3,10 @@
#include "common.h"

struct Device {
    enum Type { INVALID_DEVICE_TYPE = 0, CPU, CUDA };

    Type type = INVALID_DEVICE_TYPE;
    int idx = 0;

    static constexpr Device cpu(int idx = 0) {
        return Device{CPU, idx};
@@ -23,21 +20,29 @@ struct Device {
class Buffer : public std::enable_shared_from_this<Buffer> {
public:
    virtual ~Buffer() {}

    void *getPtr() {
        return ptr;
    }
    template<typename T>
    T *getPtr() {
        return reinterpret_cast<T *>(ptr);
    }

    size_t getSize() {
        return size;
    }
    Device getDevice() {
        return device;
    }

    virtual bool isAsyncBuffer() {
        return false;
    }

protected:
    template<typename Derived>
    std::shared_ptr<Derived> shared_from_base() {
        return std::static_pointer_cast<Derived>(shared_from_this());
    }
@@ -55,9 +60,9 @@ protected:
class BufferMalloc : public Buffer {
public:
    BufferMalloc(size_t size) {
        this->size = size;
        this->device.type = Device::CPU;
        this->ptr = malloc(size);
    }
    virtual ~BufferMalloc() {
        free(this->ptr);
@@ -67,7 +72,7 @@ public:
class BufferHost : public Buffer {
public:
    BufferHost(size_t size) {
        this->size = size;
        this->device.type = Device::CPU;
        checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
    }
@@ -79,7 +84,7 @@ public:
class BufferCUDA : public Buffer {
public:
    BufferCUDA(size_t size) {
        this->size = size;
        this->device.type = Device::CUDA;
        // checkCUDA(cudaGetDevice(&this->device.idx));
        this->device.idx = CUDADeviceContext::getDevice();
@@ -96,7 +101,7 @@ public:
        }
        checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
    }

    virtual bool isAsyncBuffer() override {
        return true;
    }
};
@@ -104,7 +109,7 @@ public:
class BufferCUDASync : public Buffer {
public:
    BufferCUDASync(size_t size) {
        this->size = size;
        this->device.type = Device::CUDA;
        checkCUDA(cudaGetDevice(&this->device.idx));
        checkCUDA(cudaMalloc(&this->ptr, size));
@@ -118,8 +123,8 @@ class BufferView : public Buffer {
public:
    BufferView(std::shared_ptr<Buffer> reference, size_t offset, size_t size) : reference(reference) {
        assert(offset + size <= reference->getSize());
        this->ptr = (void *)((std::uint8_t *)reference->getPtr() + offset);
        this->size = size;
        this->device = reference->getDevice();
    }
@@ -213,23 +218,31 @@ struct TensorShape {
    }
};
class Tensor {
public:
    enum ScalarType {
        INVALID_SCALAR_TYPE,
        INT8,
        INT16,
        INT32,
        INT64,
        FP16,
        FP32,
        BF16,
        FP8_E4M3,
        FP8_E5M2,
    };

    struct TensorOptions {
        Device device_;
        ScalarType dtype_;

        Device device() const {
            return device_;
        }
        ScalarType dtype() const {
            return dtype_;
        }

        TensorOptions device(Device dev) const {
            TensorOptions result(*this);
@@ -244,56 +257,95 @@ public:
    };

    static const std::map<ScalarType, size_t> scalarSize;

public:
    TensorShape shape;
    ScalarType scalarType;
    std::shared_ptr<Buffer> buffer;

public:
    bool valid() const {
        return shape.dataExtent.size() > 0;
    }
    int size(int dim) const {
        return shape[dim];
    }
    bool is_contiguous() const {
        return shape.is_contiguous();
    }
    std::vector<int> sizes() const {
        return shape.dataExtent;
    }

    bool is_cuda() const {
        return device().type == Device::CUDA;
    }

    TensorOptions options() const {
        return TensorOptions{device(), dtype()};
    }
    int get_device() const {
        return device().idx;
    }

    template<typename T>
    T *data_ptr() {
        return reinterpret_cast<T *>(data_ptr());
    }
    template<typename T>
    const T *data_ptr() const {
        return reinterpret_cast<const T *>(data_ptr());
    }

    const void *data_ptr() const {
        return buffer->getPtr<char>() + shape.offset * scalar_size();
    }
    void *data_ptr() {
        return buffer->getPtr<char>() + shape.offset * scalar_size();
    }

    Device device() const {
        return buffer->getDevice();
    }

    ScalarType scalar_type() const {
        return scalarType;
    }
    ScalarType dtype() const {
        return scalar_type();
    }

    size_t stride(int dim) const {
        return shape.stride(dim);
    }

    size_t numel() const {
        return shape.size();
    }
    size_t ndims() const {
        return shape.ndims();
    }

    size_t dim() const {
        return ndims();
    }

    size_t scalar_size() const {
        return scalarSize.at(scalarType);
    }
    Tensor operator[](int idx) const {
        assert(ndims() > 1);
        Tensor result;
        result.shape = std::vector<int>(this->shape.dataExtent.begin() + 1, this->shape.dataExtent.end());
        size_t size = stride(0) * scalar_size();
        result.buffer = std::make_shared<BufferView>(this->buffer, idx * size, size);
        result.scalarType = this->scalarType;
        return result;
    }

    template<typename T>
    const T &at(const std::vector<int> &idx) const {
        assert(ndims() == idx.size());
        int64_t offset = 0;
        for (size_t i = 0; i < ndims(); i++) {
@@ -304,17 +356,17 @@ public:
    }

    template<typename T>
    T &at(const std::vector<int> &idx) {
        return const_cast<T &>(const_cast<const Tensor *>(this)->at<T>(idx));
    }

    Tensor slice(int dim, int from, int to) const {
        assert(from <= to);
        Tensor result;
        result.buffer = this->buffer;
        result.scalarType = this->scalarType;
        result.shape = TensorShape(this->shape.dataExtent);
        result.shape[dim] = to - from;
        result.shape.dataStride.resize(result.shape.ndims());
        for (int i = 0; i < result.shape.ndims(); i++) {
@@ -326,7 +378,7 @@ public:
    }

    Tensor transpose(int dim1, int dim2) const {
        Tensor result;
        result.buffer = this->buffer;
        result.scalarType = this->scalarType;
        result.shape = TensorShape(this->shape.dataExtent);
@@ -346,9 +398,9 @@ public:
        assert(shape.size() == this->shape.size());
        assert(this->is_contiguous());
        Tensor result;
        result.buffer = this->buffer;
        result.scalarType = this->scalarType;
        result.shape = shape;
        result.shape.offset = this->shape.offset;
        return result;
    }
@@ -363,7 +415,8 @@ public:
    Tensor &zero_() {
        assert(this->is_contiguous());
        checkCUDA(cudaMemsetAsync(
            data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
        return *this;
    }

    Tensor &copy_(Tensor other) {
@@ -380,23 +433,17 @@ public:
        }

        if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
            memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
            return *this;
        }

        lockBuffer(this->buffer, getCurrentCUDAStream());
        lockBuffer(other.buffer, getCurrentCUDAStream());
        checkCUDA(cudaMemcpyAsync(data_ptr<char>(),
                                  other.data_ptr<char>(),
                                  shape.size() * scalar_size(),
                                  getCopyKind(this->device(), other.device()),
                                  getCurrentCUDAStream()));
        return *this;
    }
@@ -425,14 +472,15 @@ public:
            assert(false);
        }
        result.scalarType = scalarType;
        result.shape = shape;

        if (fill) {
            if (device.type == Device::CPU) {
                memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
            } else if (device.type == Device::CUDA) {
                CUDADeviceContext ctx(device.idx);
                checkCUDA(
                    cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
            }
        }
@@ -450,11 +498,12 @@ public:
        checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
        return result;
    }

    static Tensor
    allocate_view(TensorShape shape, ScalarType scalarType, std::shared_ptr<Buffer> buffer, size_t offset = 0) {
        Tensor result;
        result.buffer = std::make_shared<BufferView>(buffer, offset, shape.size() * scalarSize.at(scalarType));
        result.scalarType = scalarType;
        result.shape = shape;
        return result;
    }
@@ -468,13 +517,16 @@ public:
        // lockBuffer(this->buffer, getCurrentCUDAStream());
        // lockBuffer(result.buffer, getCurrentCUDAStream());
        // checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
        // getCurrentCUDAStream())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //     cudaMemcpyHostToDevice, getCurrentCUDAStream()));
        // } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //     cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
        // } else {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //     cudaMemcpyDefault, getCurrentCUDAStream()));
        // }
        return result;
    }
@@ -516,9 +568,10 @@ private:
    // }

    static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;

public:
    // before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU
    // completes
    static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
        if (!buffer->isAsyncBuffer()) {
            lockedBuffers[stream].insert(buffer);
@@ -558,5 +611,5 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
struct TensorsProvider {
    virtual ~TensorsProvider() {}

    virtual bool contains(const std::string &key) const = 0;
    virtual Tensor getTensor(const std::string &key) = 0;
};
\ No newline at end of file
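(Not part of the diff: the at() accessor in the hunk above computes a flat element offset as the dot product of the index vector with the per-dimension strides. A standalone sketch of that arithmetic, using a hypothetical flatOffset helper:)

// Illustration only of stride-based indexing; not project code.
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t flatOffset(const std::vector<int> &idx, const std::vector<int64_t> &stride) {
    assert(idx.size() == stride.size());
    int64_t offset = 0;
    for (size_t i = 0; i < idx.size(); i++) {
        offset += int64_t(idx[i]) * stride[i]; // same accumulation Tensor::at() performs
    }
    return offset;
}

int main() {
    // A contiguous {4, 8} tensor has strides {8, 1}; element (2, 3) lives at offset 19.
    assert(flatOffset({2, 3}, {8, 1}) == 19);
    return 0;
}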
@@ -22,13 +22,15 @@ Tensor GELU::forward(Tensor x) {
//     return out;
// }

// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor
// quantized_scale_buffer, Tensor quantized_sum_buffer) {
//     Tensor out = SiluAndMul::forward(x);
//     invoke_quant_fuse_sum(quantized_mlp_act_buffer, out, quantized_sum_buffer, quantized_scale_buffer);
//     return out;
// }

// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer,
// Tensor quantized_sum_buffer) {
//     Tensor out = SiluAndMul::forward(x);
//     invoke_quant(quantized_mlp_act_buffer, out, quantized_scale_buffer, {});
//     return out;
...
@@ -20,7 +20,8 @@ public:
// class SiluAndMulQuant {
// public:
//     static Tensor forward(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor
//     quantized_sum_buffer, bool act_sum) {
//         if (act_sum) {
//             return forward_with_act_sum(x, quantized_mlp_act_buffer, quantized_scale_buffer, quantized_sum_buffer);
//         } else {
@@ -28,6 +29,7 @@ public:
//     }
// }

// private:
//     static Tensor forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer,
//     Tensor quantized_sum_buffer); static Tensor forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor
//     quantized_scale_buffer, Tensor quantized_sum_buffer);
// };
@@ -25,7 +25,7 @@
class CUDAError : public std::runtime_error {
public:
    CUDAError(cudaError_t errorCode, std::source_location location)
        : std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}

public:
@@ -34,12 +34,13 @@ public:
private:
    static std::string format(cudaError_t errorCode, std::source_location location) {
        return spdlog::fmt_lib::format(
            "CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line());
    }
};

inline cudaError_t checkCUDA(cudaError_t retValue,
                             const std::source_location location = std::source_location::current()) {
    if (retValue != cudaSuccess) {
        (void)cudaGetLastError();
        throw CUDAError(retValue, location);
@@ -47,10 +48,11 @@ inline cudaError_t checkCUDA(cudaError_t retValue, const std::source_location lo
    return retValue;
}

inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
                                  const std::source_location location = std::source_location::current()) {
    if (retValue != CUBLAS_STATUS_SUCCESS) {
        throw std::runtime_error(spdlog::fmt_lib::format(
            "CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
    }
    return retValue;
}
@@ -71,8 +73,8 @@ struct CUDAStreamContext {
        stackCUDAStreams.push(stream);
    }

    CUDAStreamContext(const CUDAStreamContext &) = delete;
    CUDAStreamContext(CUDAStreamContext &&) = delete;

    ~CUDAStreamContext() {
        assert(stackCUDAStreams.top() == stream);
        stackCUDAStreams.pop();
@@ -86,7 +88,7 @@ struct CUDAStreamWrapper {
        checkCUDA(cudaStreamCreate(&stream));
    }

    CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
    CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;

    ~CUDAStreamWrapper() {
        checkCUDA(cudaStreamDestroy(stream));
@@ -100,14 +102,13 @@ struct CUDAEventWrapper {
        checkCUDA(cudaEventCreateWithFlags(&event, flags));
    }

    CUDAEventWrapper(const CUDAEventWrapper &) = delete;
    CUDAEventWrapper(CUDAEventWrapper &&) = delete;

    ~CUDAEventWrapper() {
        checkCUDA(cudaEventDestroy(event));
    }
};

/**
 * 1. hold one when entered from external code (set `device` to -1 to avoid device change)
 * 2. hold one when switching device
@@ -121,7 +122,7 @@ public:
            // previous context is reset on => external code may be executed, reset
            currentDeviceCache = -1;
        }

        ctxs.push(this);
        lastDevice = getDevice();
        if (device >= 0) {
@@ -134,7 +135,7 @@ public:
        }
    }

    CUDADeviceContext(const CUDADeviceContext &) = delete;
    CUDADeviceContext(CUDADeviceContext &&) = delete;

    ~CUDADeviceContext() {
        if (disableCache) {
@@ -148,7 +149,8 @@ public:
        if (cacheDisabled()) {
            // ctxs.empty() => we are about to return to external code, reset cache
            // otherwise => we are a nested context in a previous context with reset on, we might continue to execute
            // external code, reset
            currentDeviceCache = -1;
        }
    }
@@ -156,7 +158,6 @@ public:
    const bool disableCache;
    int lastDevice;

public:
    static int getDevice() {
        int idx = -1;
@@ -168,6 +169,7 @@ public:
        currentDeviceCache = cacheDisabled() ? -1 : idx;
        return idx;
    }

private:
    static void setDevice(int idx) {
        // TODO: deal with stream when switching device
@@ -207,11 +209,11 @@ constexpr T ceilDiv(T a, T b) {
template<typename T>
constexpr int log2Up(T value) {
    if (value <= 0)
        return 0;
    if (value == 1)
        return 0;
    return log2Up((value + 1) / 2) + 1;
}

struct CUBLASWrapper {
@@ -220,7 +222,7 @@ struct CUBLASWrapper {
    CUBLASWrapper() {
        checkCUBLAS(cublasCreate(&handle));
    }

    CUBLASWrapper(CUBLASWrapper &&) = delete;
    CUBLASWrapper(const CUBLASWrapper &&) = delete;

    ~CUBLASWrapper() {
        if (handle) {
@@ -236,6 +238,6 @@ inline std::shared_ptr<CUBLASWrapper> getCUBLAS() {
        return result;
    }
    result = std::make_shared<CUBLASWrapper>();
    inst = result;
    return result;
}
\ No newline at end of file
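(Not part of the diff: checkCUDA/checkCUBLAS above rely on a defaulted std::source_location argument, so the exception message points at the call site rather than at the helper itself. A self-contained C++20 sketch of that pattern with a generic status code and hypothetical names:)

// Illustration of the source_location idiom only; checkStatus is not project code.
#include <source_location>
#include <stdexcept>
#include <string>

inline int checkStatus(int status, std::source_location loc = std::source_location::current()) {
    if (status != 0) {
        // loc captures the caller's file/line because the default argument
        // is evaluated at the call site.
        throw std::runtime_error("error " + std::to_string(status) + " at " + loc.file_name() + ":" +
                                 std::to_string(loc.line()));
    }
    return status;
}

int main() {
    checkStatus(0);     // passes through
    // checkStatus(35); // would throw, reporting this file and this line
    return 0;
}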
@@ -9,8 +9,8 @@ public:
        ctxs.insert(this);
    }

    DebugContext(const DebugContext &) = delete;
    DebugContext(DebugContext &&) = delete;

    ~DebugContext() {
        ctxs.erase(this);
    }
@@ -19,4 +19,3 @@ public:
    static inline thread_local std::set<DebugContext *> ctxs;
};
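(Not part of the diff: DebugContext above is a non-copyable RAII object that registers itself in a thread_local set on construction and removes itself on destruction. A standalone sketch of that idiom with a hypothetical ScopedTag class:)

// Illustration only of thread_local RAII registration; not project code.
#include <cassert>
#include <set>

class ScopedTag {
public:
    ScopedTag() { live().insert(this); }
    ScopedTag(const ScopedTag &) = delete;
    ScopedTag(ScopedTag &&) = delete;
    ~ScopedTag() { live().erase(this); }

    static std::set<ScopedTag *> &live() {
        static thread_local std::set<ScopedTag *> tags; // per-thread registry
        return tags;
    }
};

int main() {
    assert(ScopedTag::live().empty());
    {
        ScopedTag a, b;
        assert(ScopedTag::live().size() == 2);
    }
    assert(ScopedTag::live().empty()); // both tags unregistered on scope exit
    return 0;
}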
@@ -22,20 +22,20 @@ Tensor from_torch(at::Tensor input) {
    }

    static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
        {at::ScalarType::Char, Tensor::INT8},
        {at::ScalarType::Byte, Tensor::INT8},
        {at::ScalarType::Int, Tensor::INT32},
        {at::ScalarType::Long, Tensor::INT64},
        {at::ScalarType::Float, Tensor::FP32},
        {at::ScalarType::Half, Tensor::FP16},
        {at::ScalarType::BFloat16, Tensor::BF16},
        {at::ScalarType::Short, Tensor::INT16},
        {at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3},
        {at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2},
    };

    result.scalarType = mapType.at(input.scalar_type());
    result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
    Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
@@ -51,15 +51,15 @@ at::Tensor to_torch(Tensor input) {
    }

    static const std::map<Tensor::ScalarType, at::ScalarType> mapType = {
        {Tensor::INT8, at::ScalarType::Byte},
        {Tensor::INT32, at::ScalarType::Int},
        {Tensor::INT64, at::ScalarType::Long},
        {Tensor::FP32, at::ScalarType::Float},
        {Tensor::FP16, at::ScalarType::Half},
        {Tensor::BF16, at::ScalarType::BFloat16},
        {Tensor::INT16, at::ScalarType::Short},
        {Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn},
        {Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2},
    };

    c10::TensorOptions opts(mapType.at(input.scalar_type()));
@@ -82,4 +82,4 @@ TorchOpContext::TorchOpContext() {
TorchOpContext::~TorchOpContext() {
    assert(stackCUDAStreams.top() == at::cuda::getCurrentCUDAStream().stream());
    stackCUDAStreams.pop();
}
\ No newline at end of file
@@ -8,15 +8,16 @@
class BufferTorchTensor : public Buffer {
public:
    BufferTorchTensor(at::Tensor tensor) : tensor(std::move(tensor)) {
        this->size = this->tensor.numel() * this->tensor.itemsize();
        this->ptr = this->tensor.data_ptr();
        this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
        this->device.idx = this->tensor.get_device();
    }

    virtual bool isAsyncBuffer() override {
        // TODO: figure out how torch manages memory
        return this->device.type == Device::CUDA;
    }

private:
    at::Tensor tensor;
};
@@ -25,7 +26,7 @@ class TorchOpContext {
public:
    TorchOpContext();

    TorchOpContext(const TorchOpContext &) = delete;
    TorchOpContext(TorchOpContext &&) = delete;

    ~TorchOpContext();
};
@@ -48,4 +49,4 @@ public:
private:
    std::map<std::string, at::Tensor> storage;
};
\ No newline at end of file
@@ -3,91 +3,84 @@
#include "dispatch_utils.h"

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                            \
    int d = input.size(-1);                                                                         \
    int num_tokens = input.numel() / d;                                                             \
    dim3 grid(num_tokens);                                                                          \
    dim3 block(std::min(d, 1024));                                                                  \
    const cudaStream_t stream = getCurrentCUDAStream();                                             \
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] {                    \
        vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                                         \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);  \
    });

void silu_and_mul(Tensor &out,   // [..., d]
                  Tensor &input) // [..., 2 * d]
{
    int64_t num_tokens = input.numel() / input.size(-1);
    int d = input.size(-1) / 2;
    dim3 grid(num_tokens);
    dim3 block(std::min(d, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    // dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
    //     vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
    //         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
    // });
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
        vllm::silu_and_mul_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
    });
}
void invoke_dequant_silu_and_mul_quant(Tensor &out,   // [..., d]
                                       Tensor &input, // [..., 2 * d]
                                       const float scale_gate,
                                       const float scale_up,
                                       const float scale_out) {
    int64_t num_tokens = input.numel() / input.size(-1);
    int d = input.size(-1) / 2;
    dim3 grid(num_tokens);
    dim3 block(std::min(d, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
        out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate, scale_up, scale_out);
}

void invoke_dequant_silu_and_mul_quant(Tensor &out,   // [..., d]
                                       Tensor &input, // [..., 2 * d]
                                       const float scale_gate,
                                       const float scale_up,
                                       Tensor &scale_out, // [num_tokens]
                                       Tensor &tmp        // [..., d]
) {
    int64_t num_tokens = input.numel() / input.size(-1);
    int d = input.size(-1) / 2;
    dim3 grid(num_tokens);
    dim3 block(std::min(d, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    vllm::dequant_silu_and_mul_quant_kernel<float *, true><<<grid, block, 0, stream>>>(out.data_ptr<int8_t>(),
                                                                                       input.data_ptr<int32_t>(),
                                                                                       d,
                                                                                       scale_gate,
                                                                                       scale_up,
                                                                                       scale_out.data_ptr<float>(),
                                                                                       tmp.data_ptr<float>());
}

void silu(Tensor &out,   // [..., d]
          Tensor &input) // [..., d]
{
    LAUNCH_ACTIVATION_KERNEL(vllm::silu);
}

void gelu_new(Tensor &out,   // [..., d]
              Tensor &input) // [..., d]
{
    LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

void gelu_fast(Tensor &out,   // [..., d]
               Tensor &input) // [..., d]
{
    LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
\ No newline at end of file
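(Not part of the diff: every launch above uses the same geometry, one block per token and at most 1024 threads per block, with each thread striding over the hidden dimension. A host-side sketch of that computation with hypothetical names:)

// Illustration only of the grid/block sizing used by the launches above.
#include <algorithm>
#include <cassert>
#include <cstdint>

struct LaunchGeom {
    int64_t gridX; // one block per token
    int blockX;    // threads per block, capped at 1024; each thread strides over d
};

static LaunchGeom activationLaunchGeom(int64_t numel, int d) {
    assert(d > 0 && numel % d == 0);
    return LaunchGeom{numel / d, std::min(d, 1024)};
}

int main() {
    LaunchGeom g = activationLaunchGeom(/*numel=*/8 * 4096, /*d=*/4096);
    assert(g.gridX == 8 && g.blockX == 1024);
    return 0;
}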
@@ -3,9 +3,8 @@
#include "common.h"
#include "Tensor.h"

void silu(Tensor &out, // [..., d]
          Tensor &input);

void silu_and_mul(Tensor &out,    // [..., d]
                  Tensor &input); // [..., 2 * d]
@@ -25,5 +24,5 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
                                       const float scale_gate,
                                       const float scale_up,
                                       Tensor &scale_out, // [num_tokens]
                                       Tensor &tmp        // [num_tokens, d]
);
\ No newline at end of file
@@ -3,116 +3,104 @@
namespace vllm {

template<typename T>
__device__ __forceinline__ T silu(const T &x) {
    // x * sigmoid(x)
    return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(scalar_t *__restrict__ out,         // [..., d]
                                    const scalar_t *__restrict__ input, // [..., 2 * d]
                                    const int d) {
    const int token_idx = blockIdx.x;
    const int64_t token_idx_d = token_idx * int64_t(d);
    const int64_t token_idx_2d = token_idx_d * 2;
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x = __ldg(&input[token_idx_2d + idx]);
        const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
        out[token_idx_d + idx] = silu(x) * y;
    }
}

// dequant int32 input, apply silu and mul, then per token quant to int8
template<typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out,          // [..., d]
                                                  const int32_t *__restrict__ input, // [..., 2 * d]
                                                  const int d,
                                                  const float scale_gate,
                                                  const float scale_up,
                                                  scale_type scale_out,             // [num_tokens]
                                                  float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
    const int token_idx = blockIdx.x;
    if constexpr (use_per_token_quant) {
        float amax_val = 0.0f;
        const float zero = 0.0f;
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
            const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
            float t = silu(x) * y;
            tmp[token_idx * d + idx] = t;
            t = t > zero ? t : -t;
            if (t > amax_val)
                amax_val = t;
        }

        __shared__ float s_amax;
        const float block_amax_val = blockReduceMax(amax_val);
        if (threadIdx.x == 0) {
            s_amax = block_amax_val;
            scale_out[token_idx] = block_amax_val / 127.0f;
        }
        __syncthreads();

        float tmp_scale = 127.0f / s_amax;
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            out[token_idx * d + idx] = float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
        }
    } else {
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
            const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
            out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
        }
    }
}
} // namespace vllm

namespace vllm {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t &)>
__global__ void activation_kernel(scalar_t *__restrict__ out,         // [..., d]
                                  const scalar_t *__restrict__ input, // [..., d]
                                  const int d) {
    const int token_idx = blockIdx.x;
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x = __ldg(&input[token_idx * d + idx]);
        out[token_idx * d + idx] = ACT_FN(x);
    }
}

} // namespace vllm

namespace vllm {

template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T &x) {
    const float x3 = (float)(x * x * x);
    const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
    return ((T)0.5) * x * (((T)1.0) + t);
}

template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
    const float f = (float)x;
    const T t = (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
    return ((T)0.5) * x * (((T)1.0) + t);
}

} // namespace vllm
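(Not part of the diff: gelu_new_kernel above is the usual tanh approximation of GELU. A host-side float reference of the same formula, for illustration only:)

// Plain C++ reference of the tanh-based GELU approximation; not the device code itself.
#include <cassert>
#include <cmath>

static float gelu_new_ref(float x) {
    const float x3 = x * x * x;
    const float t = std::tanh(0.79788456f * (x + 0.044715f * x3)); // 0.79788456 ~= sqrt(2/pi)
    return 0.5f * x * (1.0f + t);
}

int main() {
    assert(gelu_new_ref(0.0f) == 0.0f);  // GELU(0) == 0
    assert(gelu_new_ref(10.0f) > 9.99f); // ~identity for large positive x
    return 0;
}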
/*
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
@@ -13,16 +14,15 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
#include <cuda_fp16.h>
#include <cstdint>

__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
    uint32_t *h = reinterpret_cast<uint32_t *>(result);
    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
    static constexpr uint32_t TOP_MASK = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
@@ -75,25 +75,26 @@ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}

__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
    // dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
    // *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
    // *>(&result->x)); *reinterpret_cast<__nv_bfloat162 *>(&result->y) =
    // cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__nv_bfloat162
    // *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
    // *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
    // *>(&result->w));
    // return;

    // uint4 result;
    uint32_t *h = reinterpret_cast<uint32_t *>(result);
    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t MASK = 0x000f000f;
    static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

    // Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
@@ -127,4 +128,4 @@ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
    // Convert elt_67
    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
}
\ No newline at end of file
...@@ -2,14 +2,13 @@
#include <cuda_bf16.h>
#include "semaphore.h"
#include "gemm_awq.h"
// #include "../../../nunchaku/csrc/quantization/dequantize.cuh"
#include "dequantize.cuh"
#include <stdio.h>
#include "../dispatch_utils.h"
// #include "../../../nunchaku/csrc/utils.cuh"
#include "../utils.cuh"
#include <cuda_pipeline_primitives.h>
#define kInterleave 4
...@@ -29,1141 +28,1342 @@
#define L2_CACHEHINT(size)
#endif
#define KERNEL_LAUNCH_CODE                                                                                  \
    int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N;        \
    Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device());                 \
    auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>());                                \
    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K);                      \
    constexpr int SCALES_SMEM_SIZE =                                                                       \
        (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2);            \
    constexpr int kSmemByteSize =                                                                          \
        (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) *   \
        STAGES * sizeof(f16_t);                                                                            \
    if (kSmemByteSize >= 99 * 1024) {                                                                      \
        printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n",            \
               kSmemByteSize);                                                                              \
        return _out_feats;                                                                                 \
    }                                                                                                      \
    int j_factors1 = num_out_channels / CTA_N / 1;                                                         \
    dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK);                            \
    dim3 threads_per_block(WARP_SIZE, NUM_WARPS);                                                          \
    auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
    cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);         \
    kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(                                         \
        in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
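
// Worked example of the shared-memory budget above, with illustrative tile sizes (not
// necessarily the ones this repository instantiates): for f16_t = half, CTA_M = 64,
// CTA_N = 128, CTA_K = 64, STAGES = 3, G = 128 and zero padding, SCALES_SMEM_SIZE is
// 128 / (128 / 64) * 3 * 2 = 384, so kSmemByteSize = (64*64 + 128*64/4 + 384) * 3 * 2
// = 39168 bytes, well under the 99 KB guard that triggers the early return.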
template<int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
    if (N >= 8 && n >= 6)
        return 3;
    else if (N >= 4 && n >= 3)
        return 2;
    else if (N >= 2 && n >= 2)
        return 1;
    else
        return 0;
}
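
// Illustration (not part of this commit): get_log_tile<8>(6) == 3, get_log_tile<8>(3) == 2 and
// get_log_tile<1>(n) == 0 for any n, so the <1> instantiation used by gemm_w4a16_T1 below
// disables the column-band swizzle entirely. A host-side spot check (hypothetical helper):
static inline bool log_tile_spot_check() {
    return get_log_tile<8>(6) == 3 && get_log_tile<8>(3) == 2 && get_log_tile<1>(16) == 0;
}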
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
    return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}
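
// Worked example (illustration only): with log_tile = 2, linear CTA coordinates
// (blockIdx_x, blockIdx_y) = (5, 3) map to make_uint2(5 >> 2, (3 << 2) + (5 & 3)) = (1, 13),
// so consecutive blockIdx_x values sweep a 4-wide band of N tiles before moving down, which
// is intended to improve L2 reuse of the activation rows.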
template<int SLICES, int NUM_WARPS_MN>
__device__ void sync_slice(int slice_id) {
    if constexpr (SLICES == 1) {
        __syncthreads();
    } else {
        constexpr int SLICE_GROUP = (SLICES + 7) / 8;
        constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;
        const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
        asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
    }
}
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
    uint32_t smem_int_ptr;
    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
        : "=r"(smem_int_ptr)
        : "l"(ptr));
    return smem_int_ptr;
}
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
                 : "r"(addr));
}
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
                 : "r"(addr));
}
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
    const int cp_size = 16;
    asm volatile("{"
                 " .reg .pred p;"
                 " setp.ne.b32 p, %0, 0;"
                 " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
                 "}" ::"r"((int)mask),
                 "r"(smem_int_ptr),
                 "l"(src),
                 "n"(cp_size));
}
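
// Rough equivalent (sketch, assuming the predicate is known to be true): the pipeline primitive
// from <cuda_pipeline_primitives.h>, already included above, issues the same 16-byte cp.async.
// The hand-written PTX is kept because it predicates the copy on `mask` and applies the
// L2_CACHEHINT policy, which the plain primitive does not expose. The function name below is
// hypothetical and not used by this file.
__device__ __inline__ void cp_async_cg_unpredicated_sketch(void *dst_shared, const uint4 *__restrict__ src) {
    __pipeline_memcpy_async(dst_shared, src, 16); // 16 bytes per thread, same size as cp_async_cg_A
}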
template<typename f16_t>
__device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16_t *B_shared_warp);
template<>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]),
          "r"(((unsigned *)A_shared_warp)[1]),
          "r"(((unsigned *)A_shared_warp)[2]),
          "r"(((unsigned *)A_shared_warp)[3]),
          "r"(((unsigned *)B_shared_warp)[0]),
          "r"(((unsigned *)B_shared_warp)[1]),
          "f"(((float *)C_warp)[0]),
          "f"(((float *)C_warp)[1]),
          "f"(((float *)C_warp)[2]),
          "f"(((float *)C_warp)[3]));
}
template<>
__device__ __inline__ void
mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]),
          "r"(((unsigned *)A_shared_warp)[1]),
          "r"(((unsigned *)A_shared_warp)[2]),
          "r"(((unsigned *)A_shared_warp)[3]),
          "r"(((unsigned *)B_shared_warp)[0]),
          "r"(((unsigned *)B_shared_warp)[1]),
          "f"(((float *)C_warp)[0]),
          "f"(((float *)C_warp)[1]),
          "f"(((float *)C_warp)[2]),
          "f"(((float *)C_warp)[3]));
}
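
// Reference semantics (illustration only, not part of this commit): one mma.sync.m16n8k16
// computes D(16x8) = A(16x16) * B(16x8) + C(16x8) with f16/bf16 inputs and f32 accumulation,
// distributed over the 32 threads of a warp. A scalar host model, with A row-major and B
// stored column-major as required by the .row.col qualifier; helper name is hypothetical.
static inline void mma_m16n8k16_reference(const float *A, const float *B, float *C) {
    for (int m = 0; m < 16; ++m)
        for (int n = 0; n < 8; ++n) {
            float acc = C[m * 8 + n];
            for (int k = 0; k < 16; ++k)
                acc += A[m * 16 + k] * B[n * 16 + k]; // column n of B is contiguous in k
            C[m * 8 + n] = acc;
        }
}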
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(f16_t *src, f16_t *dst, int global_nrows, int global_ncols,
                                                        int cta_offset_m, int cta_offset_n, int cta_offset_k,
                                                        int global_iter_k, int shared_iter_k, bool mask) {
    constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
        void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
        uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE +
                                   global_iter_k * CTA_K + cta_offset_k);
        // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
        if constexpr (STAGES > 1) {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
        } else {
            if (local_mask & (ld_row + cta_offset_m < global_nrows))
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}
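
// Swizzle sketch (illustration only): because & binds tighter than ^, the expression
// (ld_col ^ (ld_row) & 7) above is ld_col ^ (ld_row & 7), i.e. the 16-byte chunk index is
// XOR-ed with the low 3 bits of the row. The intent is that the eight rows read together by
// one ldmatrix land in different shared-memory banks. A host model of the column permutation
// (hypothetical helper):
static inline int swizzled_chunk_reference(int row, int chunk) {
    return chunk ^ (row & 7); // 8 chunks of PACK_SIZE elements per shared-memory row
}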
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(f16_t *src, f16_t *dst, int global_ncols, int cta_offset_m,
                                                        int cta_offset_n, int cta_offset_k, int global_iter_k,
                                                        int shared_iter_k, bool mask) {
    constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col = (threadIdx.x % threads_per_row);
        int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
        void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
        uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols +
                                   ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
        if constexpr (STAGES > 1) {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask);
        } else {
            if (local_mask)
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales(f16_t *src, f16_t *dst, f16_t *src_z, f16_t *dst_z,
                                                             int global_ncols, int cta_offset_m, int cta_offset_n,
                                                             int cta_offset_k, int global_iter_k, int shared_iter_k,
                                                             bool mask) {
    constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
    constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = LD_AMOUNT / PACK_SIZE / threads_used;
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    constexpr int kSmemCol = CTA_N;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;

    void *dst_ptr =
        (void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n +
                               (threadIdx.x / threads_per_row) * global_ncols +
                               (threadIdx.x % threads_per_row) * PACK_SIZE);
    void *dst_ptr_z =
        (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n +
                                 (threadIdx.x / threads_per_row) * global_ncols +
                                 (threadIdx.x % threads_per_row) * PACK_SIZE);
    if (STAGES > 1) {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
        uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    } else {
        if (local_mask) {
            *(uint4 *)dst_ptr = *src_ptr;
            *(uint4 *)dst_ptr_z = *src_ptr_z;
        }
    }
}
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) {
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
        int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
        int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
        int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
        void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
        uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
        ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
    }
}
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B(f16_t *src, f16_t *src_scales, f16_t *src_zeros, f16_t *dst,
                                                    f16_t *dst_fp16, int warp_offset_m, int warp_offset_n,
                                                    int warp_offset_k, int k_0_1) {
    using f162_t = typename packed_as<f16_t, 2>::type;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
    int c0 = ((threadIdx.x / 8) % 2) * 8;
    int r = r0 / 4;
    int c = (r0 % 4) * 16 + c0;
    int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
    if constexpr (ldmatrix) {
#pragma unroll
        for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
            void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol +
                                      shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol +
                                      c_swizzled + warp_offset_k);
            uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
            ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
        }
    }
#pragma unroll
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
        f16_t scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) +
                                 threadIdx.x / 4];
        f16_t zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) +
                               threadIdx.x / 4];
        f162_t scale2 = f162f162(scale);
        f162_t zero2 = f162f162(zero);
        f162_t loaded[4];
        dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8),
                                reinterpret_cast<uint4 *>(loaded));
#pragma unroll
        for (int i = 0; i < 4; i++) {
            loaded[i] = __hfma2(loaded[i], scale2, zero2);
        }
        *reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
    }
}
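
// Scalar model (illustration only) of the fused dequantize step above: each 4-bit weight q is
// first expanded to f16/bf16 by dequantize_s4_to_fp16x2 and then mapped to scale * q + zero for
// its quantization group, which is exactly the __hfma2(loaded[i], scale2, zero2) line.
// Helper name is hypothetical.
static inline float dequant_awq_reference(int q, float scale, float zero) {
    return scale * (float)q + zero; // per-group affine dequantization
}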
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G,
         int SPLITK>
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales,
                              f16_t *__restrict__ zeros, f16_t *__restrict__ C, int *__restrict__ semaphores,
                              int M, int N, int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    trap_unsupported_arch();
    return;
#endif
    using f162_t = typename packed_as<f16_t, 2>::type;
    constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
    constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
    constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
    constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
    constexpr int SLICES = CTA_K / WARP_K;
    int num_blocks_n = (N + CTA_N - 1) / CTA_N;
    int num_blocks_m = (M + CTA_M - 1) / CTA_M;
    int blockIdx_x = 0;
    int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
    int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
    const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
    int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
    int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
    const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
    blockIdx_m = block_idx_mapping.x;
    blockIdx_n = block_idx_mapping.y;
    float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
    constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
    constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
    constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
    constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
    constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
    constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
    constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
    constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
    constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
    constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load;
    extern __shared__ half mem_shared[];
    f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
    f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
    f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
    f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
    float *C_shared = reinterpret_cast<float *>(mem_shared);
    f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
    f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
    f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
    int cta_offset_m = blockIdx_m * CTA_M;
    int cta_offset_n = blockIdx_n * CTA_N;
    int cta_offset_k = blockIdx_z * (K / SPLITK);
    int warp_mn = threadIdx.y % NUM_WARPS_MN;
    int slice_id = threadIdx.y / NUM_WARPS_MN;
    int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
    int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
    int warp_offset_k = slice_id * WARP_K;

    for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
        C_warp[i] = 0.0;

    int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
    int k_0_0_ld = 0;
    int k_0_0 = 0;
    constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
    for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
        global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
            A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k,
            k_0_0_ld, 0, true);
        global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
            B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k,
            k_0_0_ld, 0, true);
        global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
            scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
            zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
            N, cta_offset_m, cta_offset_n, cta_offset_k,
            k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
        if constexpr (STAGES > 1)
            __pipeline_commit();
    }
    if constexpr (STAGES > 1)
        __pipeline_wait_prior(STAGES - 2);
    __syncthreads();

    share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
        A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
        B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0],
        warp_offset_m, warp_offset_n, warp_offset_k, 0);
    constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;

    for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
        int ld_stage = k_0_0_ld % STAGES;
        int compute_stage = k_0_0 % STAGES;
        f16_t *A_shared_this_compute_stage;
        f16_t *B_shared_this_compute_stage;
        f16_t *scales_shared_this_compute_stage;
        f16_t *zeros_shared_this_compute_stage;
#pragma unroll
        for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
            A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
            B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
            scales_shared_this_compute_stage =
                scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
            zeros_shared_this_compute_stage =
                zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
            share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
                A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2],
                warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
            if ((iter_k + 1) % kInterleave == 0) {
                if (compute_stage % 2 == 1) {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                } else {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                }
            } else {
                if (compute_stage % 2 == 1) {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                } else {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                }
            }
            f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
            f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];

            for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
                for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
                                 A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
                                 A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
                }
            }

            if (iter_k < WARP_K / INTRIN_K - 1) {
                if constexpr (STAGES == 1)
                    __syncthreads();
                global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k,
                    k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k,
                    k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
            }

            if (iter_k == WARP_K / INTRIN_K - 2) {
                if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
                    __syncthreads();
                }
                global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k,
                    k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k,
                    k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
                global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
                    scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
                    zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
                    N, cta_offset_m, cta_offset_n, cta_offset_k,
                    k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
                if constexpr (STAGES > 1) {
                    __pipeline_commit();
                    __pipeline_wait_prior(STAGES - 2);
                }
                compute_stage = (k_0_0 + 1) % STAGES;
                __syncthreads();
            }
        }
    }
    __pipeline_commit();
    __pipeline_wait_prior(0);
    __syncthreads();
    if constexpr (SLICES > 1) {
#pragma unroll
        for (int z = 0; z < SLICES; ++z) {
            if (slice_id == z) {
#pragma unroll
                for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
                    for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
                        for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
                            if (z > 0) {
                                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] +=
                                    C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n +
                                             ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N +
                                             (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
                            }
                            C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                                     ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
                                     (local_id % 2) + (threadIdx.x % 4) * 2] =
                                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
                        };
                    }
                }
            }
            __syncthreads();
        }
        if (slice_id == 0) {
#pragma unroll
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
                        C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] =
                            C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                                     ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
                                     (local_id % 2) + (threadIdx.x % 4) * 2];
                    };
                }
            }
        }
    }
    if (slice_id == 0) {
        Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
        if constexpr (SPLITK > 1) {
            semaphore.fetch();
        }
        if (blockIdx_z != 0) {
            semaphore.wait(blockIdx_z);
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
                        int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
                                        ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
                        if (write_row < M) {
                            f162_t *existing_psum_ptr = reinterpret_cast<f162_t *>(
                                C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
                                (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2);
                            *existing_psum_ptr =
                                __hadd2(*existing_psum_ptr,
                                        cuda_cast<f162_t>(*reinterpret_cast<float2 *>(
                                            C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
                        }
                    };
                }
            }
        } else {
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
                        int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
                                        ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
                        if (write_row < M) {
                            *reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n +
                                                        ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2) +
                                                        (threadIdx.x % 4) * 2) =
                                cuda_cast<f162_t>(*reinterpret_cast<float2 *>(
                                    C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id));
                        }
                    };
                }
            }
        }
        if constexpr (SPLITK > 1) {
            int lock = 0;
            if (SPLITK == blockIdx_z + 1) {
                lock = 0;
            } else {
                lock = blockIdx_z + 1;
            }
            semaphore.release(lock);
        }
    }
}
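
// Grid-shape sketch (illustration only; the real launch is the KERNEL_LAUNCH_CODE macro above):
// with SPLITK partial-K slices the 1-D grid holds (M tiles * N tiles) * SPLITK blocks. Slice
// blockIdx_z == 0 stores its partial sums directly; every later slice waits on the per-tile
// semaphore and accumulates into C, and the last slice resets the lock to 0. Helper name is
// hypothetical.
static inline int gemm_w4a16_grid_size_reference(int M, int N, int cta_m, int cta_n, int splitk) {
    int tiles = ((M + cta_m - 1) / cta_m) * ((N + cta_n - 1) / cta_n);
    return tiles * splitk;
}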
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A_T2(f16_t *src,
                                                           f16_t *dst,
                                                           int global_nrows,
                                                           int global_ncols,
                                                           int cta_offset_m,
                                                           int cta_offset_n,
                                                           int global_iter_k,
                                                           int shared_iter_k,
                                                           bool mask) {
    constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
        void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
        uint4 *src_ptr =
            (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE +
                      global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n *
                                              // global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols +
                                              // (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K
                                              // + (threadIdx.x % threads_per_row) * PACK_SIZE);
        if constexpr (STAGES > 1) {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
        } else {
            if (local_mask & (ld_row + cta_offset_m < global_nrows))
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}
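// Note on the swizzle above: `&` binds tighter than `^`, so (ld_col ^ (ld_row) & 7) is
// ld_col ^ (ld_row & 7). Each row XORs its 16-byte pack index with the low 3 bits of the row,
// which spreads the packs of neighbouring rows over different shared-memory banks so the later
// ldmatrix reads can avoid bank conflicts.
// Minimal sketch of the same mapping with explicit parentheses (illustration only, not used by
// the kernels; `pack` is the 8-element pack index, i.e. ld_col above):
constexpr int swizzled_smem_offset_sketch(int ld_row, int pack, int smem_cols) {
    int swizzled_pack = pack ^ (ld_row & 7);                // same arithmetic as the kernel, spelled out
    return ld_row * smem_cols + swizzled_pack * PACK_SIZE;  // element offset into the padded tile
}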
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B_T2(f16_t *src,
                                                           f16_t *dst,
                                                           int global_ncols,
                                                           int cta_offset_m,
                                                           int cta_offset_n,
                                                           int global_iter_k,
                                                           int shared_iter_k,
                                                           bool mask) {
    constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;

        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col = (threadIdx.x % threads_per_row);
        int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
        void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
        uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols +
                                   ld_row * global_ncols + ld_col * PACK_SIZE);
        if constexpr (STAGES > 1) {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask);
        } else {
            if (local_mask)
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales_T2(f16_t *src,
                                                                f16_t *dst,
                                                                f16_t *src_z,
                                                                f16_t *dst_z,
                                                                int global_ncols,
                                                                int cta_offset_m,
                                                                int cta_offset_n,
                                                                int global_iter_k,
                                                                int shared_iter_k,
                                                                bool mask) {
    constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    constexpr int kSmemCol = CTA_N;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int g_idx = global_iter_k * CTA_K / G;

    void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
    void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr_z =
        (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
    if (STAGES > 1) {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
        uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    } else {
        if (local_mask) {
            *(uint4 *)dst_ptr = *src_ptr;
            *(uint4 *)dst_ptr_z = *src_ptr_z;
        }
    }
}
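// Scales and zeros are group-wise: one (scale, pre-scaled zero) pair covers G consecutive
// K elements of each output channel, so g_idx only advances every G / CTA_K iterations of the
// K loop and fresh copies are fetched at that same interval.
// Scalar reference of the dequantization applied later in registers (a sketch for clarity only;
// the kernels use the packed __hfma2 form and assume the zero point is already folded into
// `zero` at export time):
__host__ __device__ inline float dequant_w4_reference(int q, float scale, float zero) {
    // matches loaded[i] = __hfma2(loaded[i], scale2, zero2): w = q * scale + zero
    return static_cast<float>(q) * scale + zero;
}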
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A_T2(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1) {
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;

    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
        int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
        int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
        int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
        void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
        uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
        ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
    }
}
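// share_to_reg_one_stage_A_T2 pulls A fragments from shared memory with ldmatrix: each thread
// supplies one (already swizzled) row address and the instruction distributes four 8x8 b16 tiles
// into registers in the layout the mma.m16n8k16 fragments expect.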
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src,
                                                        f16_t *src_scales,
                                                        f16_t *src_zeros,
                                                        f16_t *dst,
                                                        f16_t *dst_fp16,
                                                        int warp_offset_m,
                                                        int warp_offset_n,
                                                        int k_0_1) {
    using f162_t = typename packed_as<f16_t, 2>::type;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
    int c0 = ((threadIdx.x / 8) % 2) * 8;
    int r = r0 / 4;
    int c = (r0 % 4) * 16 + c0;
    int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;

    if constexpr (ldmatrix) {
#pragma unroll
        for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
            void *addr_ptr =
                (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol +
                         k_0_1 * 16 + r * kSmemCol + c_swizzled);
            uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
            ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
        }
    }
#pragma unroll
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
        f16_t scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        f16_t zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        f162_t scale2 = f162f162(scale);
        f162_t zero2 = f162f162(zero);
        f162_t loaded[4];
        dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8),
                                reinterpret_cast<uint4 *>(loaded));
#pragma unroll
        for (int i = 0; i < 4; i++) {
            loaded[i] = __hfma2(loaded[i], scale2, zero2);
        }
        *reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
    }
}
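// For B, `dst` keeps the raw 4-bit fragments fetched by ldmatrix (only refreshed when the
// `ldmatrix` template flag is true, i.e. every kInterleave k-steps at the call sites), while
// `dst_fp16` receives the dequantized values: dequantize_s4_to_fp16x2 expands 8 packed codes to
// 8 half/bf16 values and __hfma2 then applies the per-group scale and pre-scaled zero.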
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
                              f16_t *__restrict__ B,
                              f16_t *__restrict__ scales,
                              f16_t *__restrict__ zeros,
                              f16_t *__restrict__ C,
                              int M,
                              int N,
                              int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    trap_unsupported_arch();
    return;
#endif
    using f162_t = typename packed_as<f16_t, 2>::type;
    constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
    constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
    int num_blocks_n = (N + CTA_N - 1) / CTA_N;
    int num_blocks_m = (M + CTA_M - 1) / CTA_M;
    int blockIdx_x = 0;
    int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
    int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
    const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
    int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
    int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
    const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
    blockIdx_m = block_idx_mapping.x;
    blockIdx_n = block_idx_mapping.y;

    float C_warp[CTA_M * CTA_N / CTA_SIZE];
    constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
    constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
    constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
    constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
    constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
    constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
    constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
    constexpr int kSmemSizeZeros = CTA_N * STAGES / 2;
    constexpr int scales_load_interval = G / CTA_K;
    extern __shared__ half mem_shared[];
    f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
    f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
    f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
    f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
    f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
    f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
    f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
    int cta_offset_m = blockIdx_m * CTA_M;
    int cta_offset_n = blockIdx_n * CTA_N;
    int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
    int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;

    for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
        C_warp[i] = 0.0;

    int gemm_iters = (K + CTA_K - 1) / CTA_K;
    int k_0_0_ld = 0;
    int k_0_0 = 0;
    constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
    for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
        global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
            A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
        global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
            B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
        global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
            scales,
            scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
            zeros,
            zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
            N,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            0,
            k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
        if constexpr (STAGES > 1)
            __pipeline_commit();
    }
    if constexpr (STAGES > 1)
        __pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared,
scales_shared,
zeros_shared,
B_shared_warp_tmp_[0],
B_shared_warp_[0],
warp_offset_m,
warp_offset_n,
0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared_this_compute_stage,
A_shared_warp_[(iter_k + 1) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0) {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
}
} else {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
}
}
__syncthreads();
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1) {
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2) {
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
__syncthreads();
}
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales,
scales_shared + (ld_stage / scales_load_interval) * CTA_N,
zeros,
zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1) {
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
        }
    }

    for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
            for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
                int write_row =
                    cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
                if (write_row < M) {
                    *reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
                                                (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
                        cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
                                                                      ax1_0_1 * 8 + local_id));
                }
            };
        }
    }
}
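// Pipeline summary for gemm_w4a16_T2: the prologue issues cp.async copies for STAGES - 1 stages
// of A, B, scales and zeros; the main K loop overlaps mma.m16n8k16 on the current stage with the
// asynchronous prefetch of the next one (__pipeline_commit / __pipeline_wait_prior keep STAGES - 2
// copies in flight); the epilogue casts the fp32 accumulators in C_warp back to f16_t and writes
// them to C with a bounds check on M.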
Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, Tensor _zeros) {
    auto output_shape = _in_feats.shape.dataExtent;
    output_shape.back() = _kernel.size(0) * kInterleave;
    int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
    int num_in_channels = _in_feats.size(-1);
    auto options = Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    auto options_int = Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
    Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
    int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
    int num_out_channels = _out_feats.size(-1);

    if (_in_feats.scalar_type() == Tensor::FP16) {
        using f16_t = half;

        auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
        auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
        auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
        auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
        auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());

        if (num_out_feats <= 32) {
            constexpr int G = 128;
            constexpr int CTA_M = 16;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 2;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 64) {
            constexpr int G = 128;
            constexpr int CTA_M = 16;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 3;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 128) {
            constexpr int G = 128;
            constexpr int CTA_M = 32;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 128;
            constexpr int WARP_M = 32;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 192) {
            constexpr int G = 128;
            constexpr int CTA_M = 64;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 64;
            constexpr int WARP_M = 64;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else {
            constexpr int G = 128;
            constexpr int CTA_M = 64;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 64;
            constexpr int WARP_M = 64;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int STAGES = 4;

            constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
            constexpr int kSmemByteSize =
                (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES *
                sizeof(f16_t);
            if (kSmemByteSize >= 99 * 1024) {
                printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
                return _out_feats;
            }
            int j_factors1 = num_out_channels / CTA_N / 1;
            dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
            dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
            auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
            cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
            kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
        }
    } else if (_in_feats.scalar_type() == Tensor::BF16) {
        using f16_t = __nv_bfloat16;

        auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
        auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
        auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
        auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
        auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());

        if (num_out_feats <= 32) {
            constexpr int G = 128;
            constexpr int CTA_M = 16;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 2;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 64) {
            constexpr int G = 128;
            constexpr int CTA_M = 16;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 3;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 128) {
            constexpr int G = 128;
            constexpr int CTA_M = 32;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 128;
            constexpr int WARP_M = 32;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 192) {
            constexpr int G = 128;
            constexpr int CTA_M = 64;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 64;
            constexpr int WARP_M = 64;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else {
            constexpr int G = 128;
            constexpr int CTA_M = 64;
            constexpr int CTA_N = 128;
            constexpr int CTA_K = 64;
            constexpr int WARP_M = 64;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int STAGES = 4;

            constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
            constexpr int kSmemByteSize =
                (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES *
                sizeof(f16_t);
            if (kSmemByteSize >= 99 * 1024) {
                printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
                return _out_feats;
            }
            int j_factors1 = num_out_channels / CTA_N / 1;
            dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
            dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
            auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
            cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
            kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
        }
    } else {
        throw std::runtime_error("Unsupported input type");
    }

    return _out_feats;
}
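// Hedged usage sketch for the launcher above (illustration only; the wrapper name and the shape
// comments are assumptions read off this file, not a documented API). Only FP16 and BF16
// activations are accepted, and the output width comes from _kernel.size(0) * kInterleave.
inline Tensor awq_gemm_example(Tensor x, Tensor qweight, Tensor scales, Tensor zeros) {
    // x:              [..., in_channels] activations in FP16 or BF16
    // qweight:        packed 4-bit weights in this kernel's interleaved layout
    // scales / zeros: group-wise dequantization parameters matching that packing
    Tensor y = awq_gemm_forward_cuda(x, qweight, scales, zeros); // y: [..., qweight.size(0) * kInterleave]
    return y;
}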
\ No newline at end of file
@@ -3,9 +3,4 @@
#include "common.h"
#include "Tensor.h"

Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, Tensor _zeros);
/*
 * Modified from NVIDIA
 * [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -39,106 +40,98 @@
#define MEM_ACCESS_SIZE 128

// Reduce sum within the warp using the tree reduction algorithm.
template<typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t *psum, float (*out_smem)[Num * 4]) {
    // kInterleave = 4
    float fpsum[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i) {
        fpsum[i] = static_cast<float>(psum[i]);
    }

#pragma unroll
    for (int i = 0; i < Num; ++i) {
        // T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
    }
    __syncthreads();
    int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
    if (lane == 0 || lane == 2 || lane == 4 || lane == 6) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            out_smem[warp][i * 4 + lane / 2] = fpsum[i];
        }
    }
    __syncthreads();
};
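// The reduction above only folds lanes whose ids differ in bits 4, 3 and 0 (xor masks 16, 8, 1):
// with kInterleave = 4, the remaining lane bits select different output rows rather than partial
// sums of the same one, so they must not be summed together. For comparison, a plain full-warp
// butterfly sum xors over every bit (generic sketch, not used by these kernels):
__device__ __forceinline__ float warp_allreduce_sum_sketch(float v) {
    // after the loop, every lane holds the sum of all WARP_SIZE inputs
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1)
        v += __shfl_xor_sync(~0u, v, mask);
    return v;
}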
__device__ __forceinline__ int make_divisible(int c, int divisor) {
    return (c + divisor - 1) / divisor;
}

template<typename half_t>
__device__ __forceinline__ packed_as<half_t, 2>::type half2half2(half_t x);

template<>
__device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
    return __half2half2(x);
}

template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
}

template<typename T>
__device__ __forceinline__ float2 half22float2(T val);

template<>
__device__ __forceinline__ float2 half22float2<half2>(half2 val) {
    return __half22float2(val);
}

template<>
__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
    return __bfloat1622float2(val);
}
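// half2half2 and half22float2 wrap the type-specific intrinsics (__half2half2 / __bfloat162bfloat162
// and __half22float2 / __bfloat1622float2) so gemv_kernel below can stay generic over half and
// __nv_bfloat16 through the packed_as<> type mapping.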
template<typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(const half_t *inputs,
                            const uint32_t *weight,
                            const half_t *scales,
                            const half_t *zeros,
                            half_t *outputs,
                            const int IC,
                            const int OC) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
        trap_unsupported_arch();
        return;
    }
#endif
    using half2_t = typename packed_as<half_t, 2>::type;
    using accum_t = float;
    using accum2_t = typename packed_as<accum_t, 2>::type;

    const int kStride = 64;
    const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
    const int kThreadsNumPerTile = kStride / kElemsPerThread;
    // assert(MEM_ACCESS_SIZE == 128);

    // static constexpr int kShuffleSize = 32;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 4;
    static constexpr int kShuffleStrided = 4;

    constexpr int Num = NPerBlock * Batch;
    constexpr int kInterleave = 4;

    alignas(16) half_t local_inputs[kElemsPerThread];
    alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
    alignas(16) half_t half_weight_buffer[kElemsPerThread];
    alignas(16) half_t dequantized_weight[kElemsPerThread * NPerBlock];
    alignas(16) half_t local_scale[NPerBlock];
    alignas(16) half_t local_scaled_zeros[NPerBlock];
@@ -146,7 +139,7 @@ __global__ void gemv_kernel(
    accum_t psum[Num];
    for (int i = 0; i < Num; ++i)
        psum[i] = static_cast<accum_t>(0.f);

    // extern __shared__ uint8_t shmem[];
    // float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
...@@ -154,80 +147,67 @@ __global__ void gemv_kernel( ...@@ -154,80 +147,67 @@ __global__ void gemv_kernel(
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave; const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave; const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride +
+ (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread; (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize; const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible // TODO: use make_divisible
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR; const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half_t* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC; const half_t *scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC; const half_t *zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* inputs_ptr = inputs + act_k_offset; const half_t *inputs_ptr = inputs + act_k_offset;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave; const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC; const int scale_forward_step = act_forward_step / GroupSize * OC;
// Main loop iteration, each block completes the outputs for several OCs // Main loop iteration, each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) {
{ // Load qweight, scales and scaled_zeros
// Load qweight, scales and scaled_zeros #pragma unroll
#pragma unroll for (int idx = 0; idx < NPerBlock; ++idx) {
for (int idx = 0; idx < NPerBlock; ++idx)
{
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit) // use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
*((float4*)(local_qweights)) = *((float4 *)(local_qweights)) = *((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
*((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR)); local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scale[idx] = *(scale_ptr + idx * kInterleave); local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
// Map int4 qweight to fp format #pragma unroll
#pragma unroll for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) {
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
// Converts 32 bits (8 x int4) to 8 fp16 // Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR)); dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i),
reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
} }
// Dequantize (apply s/z) and shuffle elements to match the weight packing format // Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll #pragma unroll
for (int i = 0; i < kShuffleContinous; ++i) for (int i = 0; i < kShuffleContinous; ++i) {
{ #pragma unroll
#pragma unroll for (int j = 0; j < kShuffleStrided; ++j) {
for (int j = 0; j < kShuffleStrided; ++j) half2_t w = *reinterpret_cast<half2_t *>(half_weight_buffer +
{ (i + j * kShuffleContinous) * kShuffleBasicTile);
half2_t w = w = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
*reinterpret_cast<half2_t*>( dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
);
w = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)
* NPerBlock + idx]
= w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)
* NPerBlock + idx]
= w.y;
} }
} }
} }
#pragma unroll
    for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) {
        const half_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
        for (int idx = 0; idx < kElemsPerThread / 8; ++idx) {
            // load activation, 8 halves (128 bits) / step.
            *((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
        }
        // Perform the MACs
#pragma unroll
        for (int x = 0; x < NPerBlock / 2; ++x) {
#pragma unroll
            for (int y = 0; y < kElemsPerThread; ++y) {
                accum2_t prod = cuda_cast<accum2_t>(
                    __hmul2(*reinterpret_cast<half2_t *>(dequantized_weight + y * NPerBlock + x * 2),
                            half2half2(local_inputs[y])));
                *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2) =
                    prod + *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
                // *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2)
                //     = __hfma2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2),
                //               half2half2(local_inputs[y]),
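Stripped of the half2 vectorization, the MAC loop above accumulates, per thread, one partial dot product for each (batch, output-channel) pair the thread owns. A scalar sketch of the same indexing (names are illustrative; plain float stands in for the kernel's accumulation type):

// Scalar reference for the per-thread partial sums: dequantized_weight is laid
// out [kElemsPerThread][NPerBlock], so adjacent output channels for a fixed y
// are contiguous, which is what enables the half2 loads in the kernel.
template<int Batch, int NPerBlock, int kElemsPerThread>
void gemv_partial_ref(const float *weights, const float *inputs, float *psum) {
    for (int b = 0; b < Batch; ++b)
        for (int n = 0; n < NPerBlock; ++n)
            for (int y = 0; y < kElemsPerThread; ++y)
                psum[b * NPerBlock + n] += weights[y * NPerBlock + n] * inputs[b * kElemsPerThread + y];
}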
@@ -243,13 +223,11 @@ __global__ void gemv_kernel(
    warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);

    // Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
    for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) {
        int batch_idx = i / (NPerBlock * kInterleave);
        int oc_idx = i % (NPerBlock * kInterleave);
        float acc = 0.f;
        for (int j = 0; j < BlockSize / WARP_SIZE; ++j) {
            acc += out_smem[j][i];
        }
        outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half_t>(acc);
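After warp_reduce, each warp has written its reduced partials into out_smem[warp][...]; the loop above sums across the BlockSize / WARP_SIZE warp slots and maps the flat index back to a (batch, output-channel) pair. The index arithmetic, worked through on assumed values:

// With NPerBlock = 2 and kInterleave = 4, each block covers 8 output channels
// starting at blk_row_offset. For flat index i = 13:
//   batch_idx = 13 / (2 * 4) = 1
//   oc_idx    = 13 % (2 * 4) = 5
// so the value lands in outputs[1 * OC + blk_row_offset + 5].
struct OutIndex { int batch; int oc; };
inline OutIndex map_flat_index(int i, int n_per_block, int k_interleave) {
    return OutIndex{i / (n_per_block * k_interleave), i % (n_per_block * k_interleave)};
}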
@@ -271,32 +249,24 @@ Returns:
    out_feats: tensor of shape [B, OC];
*/
Tensor gemv_awq(
    Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size) {
    return dispatchFloat16(_scaling_factors.scalar_type(), [&]<typename half_t>() {
        assert(isTypeMatch<half_t>(_in_feats.dtype()));

        auto output_shape = _in_feats.shape.dataExtent;
        output_shape.back() = n;

        auto in_feats = reinterpret_cast<half_t *>(_in_feats.data_ptr<half_t>());
        auto kernel = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
        auto zeros = reinterpret_cast<half_t *>(_zeros.data_ptr<half_t>());
        auto scaling_factors = reinterpret_cast<half_t *>(_scaling_factors.data_ptr<half_t>());

        Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
        half_t *out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());

        static constexpr int N_PER_BLOCK = 2;
        static constexpr int K_INTERLEAVE = 4;
        static constexpr int BLOCK_SIZE = 256;

        dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
        dim3 num_threads(BLOCK_SIZE);
@@ -312,9 +282,9 @@ Tensor gemv_awq(
                return;
            }
            if constexpr (M > 0) {
                gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
                    <<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                checkCUDA(cudaGetLastError());
            }
        });
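A hedged sketch of how a host-side caller might invoke this entry point; the wrapper name, tensor layouts, and the default group size are assumptions drawn from the doc comment and the launch geometry above (n must be divisible by N_PER_BLOCK * K_INTERLEAVE = 8), not taken from this repository:

// Hypothetical convenience wrapper: m = batch B, n = OC, k = IC.
// qweight holds int4 weights packed into uint32 words; scaled_zeros are the
// per-group zeros pre-multiplied by the scales, as the kernel's FMA expects.
Tensor awq_linear(Tensor in_feats, Tensor qweight, Tensor scales, Tensor scaled_zeros,
                  int B, int OC, int IC, int group_size = 64) {
    return gemv_awq(in_feats, qweight, scales, scaled_zeros, B, OC, IC, group_size);
}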
...
@@ -3,12 +3,5 @@
#include "common.h" #include "common.h"
#include "Tensor.h" #include "Tensor.h"
Tensor gemv_awq( Tensor
Tensor _in_feats, gemv_awq(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size);
Tensor _kernel,
Tensor _scaling_factors,
Tensor _zeros,
int m,
int n,
int k,
int group_size);
\ No newline at end of file
@@ -41,65 +41,54 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore {
public:
    int *lock;
    bool wait_thread;
    int state;

public:
    /// Implements a semaphore to wait for a flag to reach a given value
    __host__ __device__ Semaphore(int *lock_, int thread_id)
        : lock(lock_), wait_thread(thread_id < 0 || thread_id == 0), state(-1) {}

    /// Permit fetching the synchronization mechanism early
    __device__ void fetch() {
        if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
            asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
        }
    }

    /// Gets the internal state
    __device__ int get_state() const {
        return state;
    }

    /// Waits until the semaphore is equal to the given value
    __device__ void wait(int status = 0) {
        while (__syncthreads_and(state != status)) {
            fetch();
        }
        __syncthreads();
    }

    /// Updates the lock with the given result
    __device__ void release(int status = 0) {
        __syncthreads();

        if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
            asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
            asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
        }
    }
};
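As a usage note, this acquire/release pattern is typically used to serialize CTAs that contribute partial results to the same output tile (split-K style). A minimal sketch under assumed names (workspace, tile_id, and num_splits are hypothetical, not from this file):

// Each CTA waits for its turn, accumulates into the shared output, then passes
// the token to the next split; the last split resets the lock to 0.
__global__ void splitk_reduce_sketch(int *workspace, int tile_id, int num_splits) {
    Semaphore semaphore(workspace + tile_id, threadIdx.x);
    semaphore.fetch();                 // issue the acquire early to hide latency

    // ... compute this CTA's partial result ...

    int split_id = blockIdx.z;         // which split this CTA handles
    semaphore.wait(split_id);          // spin until the previous split released to us
    // ... add the partial result into the output tile ...
    semaphore.release(split_id + 1 == num_splits ? 0 : split_id + 1);
}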
/////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -16,4 +16,4 @@ inline void dispatchF16(Tensor::ScalarType type, F &&func) {
    } else {
        assert(false);
    }
}
\ No newline at end of file
@@ -53,17 +53,16 @@ inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
}

#pragma nv_diagnostic push
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template
// "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
template<typename T>
inline bool isTypeMatch(Tensor::ScalarType scalarType) {
    return dispatch(scalarType, []<typename scalar_t>() { return std::is_same_v<scalar_t, T>; });
}
#pragma nv_diagnostic pop

template<typename F, int... N>
inline auto dispatchVal(int val, std::integer_sequence<int, N...>, F &&func) {
    auto call = [&]<int i>() {
        if (val == i) {
@@ -82,5 +81,4 @@ inline auto dispatchBool(bool val, F &&func) {
    }
}

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) dispatchFloat(TYPE, [&]<typename scalar_t>() { __VA_ARGS__(); });
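These helpers lift runtime values into template arguments so kernels can be specialized with if constexpr. A sketch of typical call sites, assuming dispatchVal and dispatchBool forward the matched value as an int/bool template parameter (consistent with the [&]<int M>()-style lambda used by gemv_awq above); the surrounding variables are hypothetical:

// Specialize the launch on a small set of batch sizes M.
dispatchVal(m, std::integer_sequence<int, 1, 2, 3, 4, 5, 6, 7, 8>{}, [&]<int M>() {
    if constexpr (M > 0) {
        // launch the kernel instantiation templated on M
    }
});

// Resolve a runtime flag once on the host instead of branching per thread.
dispatchBool(use_bias, [&]<bool kUseBias>() {
    // instantiate the kUseBias specialization
});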
\ No newline at end of file