Unverified Commit 37a27712 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
......@@ -20,7 +20,8 @@ public:
// class SiluAndMulQuant {
// public:
// static Tensor forward(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer, bool act_sum) {
// static Tensor forward(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor
// quantized_sum_buffer, bool act_sum) {
// if (act_sum) {
// return forward_with_act_sum(x, quantized_mlp_act_buffer, quantized_scale_buffer, quantized_sum_buffer);
// } else {
......@@ -28,6 +29,7 @@ public:
// }
// }
// private:
// static Tensor forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
// static Tensor forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
// };
\ No newline at end of file
// static Tensor forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer,
// Tensor quantized_sum_buffer); static Tensor forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor
// quantized_scale_buffer, Tensor quantized_sum_buffer);
// };
......@@ -25,7 +25,7 @@
class CUDAError : public std::runtime_error {
public:
CUDAError(cudaError_t errorCode, std::source_location location)
CUDAError(cudaError_t errorCode, std::source_location location)
: std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}
public:
......@@ -34,12 +34,13 @@ public:
private:
static std::string format(cudaError_t errorCode, std::source_location location) {
return spdlog::fmt_lib::format("CUDA error: {} (at {}:{})",
cudaGetErrorString(errorCode), location.file_name(), location.line());
return spdlog::fmt_lib::format(
"CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line());
}
};
inline cudaError_t checkCUDA(cudaError_t retValue, const std::source_location location = std::source_location::current()) {
inline cudaError_t checkCUDA(cudaError_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != cudaSuccess) {
(void)cudaGetLastError();
throw CUDAError(retValue, location);
......@@ -47,10 +48,11 @@ inline cudaError_t checkCUDA(cudaError_t retValue, const std::source_location lo
return retValue;
}
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue, const std::source_location location = std::source_location::current()) {
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != CUBLAS_STATUS_SUCCESS) {
throw std::runtime_error(spdlog::fmt_lib::format("CUBLAS error: {} (at {}:{})",
cublasGetStatusString(retValue), location.file_name(), location.line()));
throw std::runtime_error(spdlog::fmt_lib::format(
"CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
}
return retValue;
}
......@@ -71,8 +73,8 @@ struct CUDAStreamContext {
stackCUDAStreams.push(stream);
}
CUDAStreamContext(const CUDAStreamContext &) = delete;
CUDAStreamContext(CUDAStreamContext &&) = delete;
CUDAStreamContext(CUDAStreamContext &&) = delete;
~CUDAStreamContext() {
assert(stackCUDAStreams.top() == stream);
stackCUDAStreams.pop();
......@@ -86,7 +88,7 @@ struct CUDAStreamWrapper {
checkCUDA(cudaStreamCreate(&stream));
}
CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
~CUDAStreamWrapper() {
checkCUDA(cudaStreamDestroy(stream));
......@@ -100,14 +102,13 @@ struct CUDAEventWrapper {
checkCUDA(cudaEventCreateWithFlags(&event, flags));
}
CUDAEventWrapper(const CUDAEventWrapper &) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
~CUDAEventWrapper() {
checkCUDA(cudaEventDestroy(event));
}
};
/**
* 1. hold one when entered from external code (set `device` to -1 to avoid device change)
* 2. hold one when switching device
......@@ -121,7 +122,7 @@ public:
// previous context is reset on => external code may be executed, reset
currentDeviceCache = -1;
}
ctxs.push(this);
lastDevice = getDevice();
if (device >= 0) {
......@@ -134,7 +135,7 @@ public:
}
}
CUDADeviceContext(const CUDADeviceContext &) = delete;
CUDADeviceContext(CUDADeviceContext &&) = delete;
CUDADeviceContext(CUDADeviceContext &&) = delete;
~CUDADeviceContext() {
if (disableCache) {
......@@ -148,7 +149,8 @@ public:
if (cacheDisabled()) {
// ctxs.empty() => we are about to return to external code, reset cache
// otherwise => we are a nested context in a previous context with reset on, we might continue to execute external code, reset
// otherwise => we are a nested context in a previous context with reset on, we might continue to execute
// external code, reset
currentDeviceCache = -1;
}
}
......@@ -156,7 +158,6 @@ public:
const bool disableCache;
int lastDevice;
public:
static int getDevice() {
int idx = -1;
......@@ -168,6 +169,7 @@ public:
currentDeviceCache = cacheDisabled() ? -1 : idx;
return idx;
}
private:
static void setDevice(int idx) {
// TODO: deal with stream when switching device
......@@ -207,11 +209,11 @@ constexpr T ceilDiv(T a, T b) {
template<typename T>
constexpr int log2Up(T value) {
if (value <= 0)
return 0;
if (value == 1)
return 0;
return log2Up((value + 1) / 2) + 1;
if (value <= 0)
return 0;
if (value == 1)
return 0;
return log2Up((value + 1) / 2) + 1;
}
struct CUBLASWrapper {
......@@ -220,7 +222,7 @@ struct CUBLASWrapper {
CUBLASWrapper() {
checkCUBLAS(cublasCreate(&handle));
}
CUBLASWrapper(CUBLASWrapper &&) = delete;
CUBLASWrapper(CUBLASWrapper &&) = delete;
CUBLASWrapper(const CUBLASWrapper &&) = delete;
~CUBLASWrapper() {
if (handle) {
......@@ -236,6 +238,6 @@ inline std::shared_ptr<CUBLASWrapper> getCUBLAS() {
return result;
}
result = std::make_shared<CUBLASWrapper>();
inst = result;
inst = result;
return result;
}
\ No newline at end of file
}
......@@ -9,8 +9,8 @@ public:
ctxs.insert(this);
}
DebugContext(const DebugContext &) = delete;
DebugContext(DebugContext &&) = delete;
DebugContext(DebugContext &&) = delete;
~DebugContext() {
ctxs.erase(this);
}
......@@ -19,4 +19,3 @@ public:
static inline thread_local std::set<DebugContext *> ctxs;
};
......@@ -22,20 +22,20 @@ Tensor from_torch(at::Tensor input) {
}
static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
{ at::ScalarType::Char, Tensor::INT8 },
{ at::ScalarType::Byte, Tensor::INT8 },
{ at::ScalarType::Int, Tensor::INT32 },
{ at::ScalarType::Long, Tensor::INT64 },
{ at::ScalarType::Float, Tensor::FP32 },
{ at::ScalarType::Half, Tensor::FP16 },
{ at::ScalarType::BFloat16, Tensor::BF16 },
{ at::ScalarType::Short, Tensor::INT16 },
{ at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3 },
{ at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2 },
{at::ScalarType::Char, Tensor::INT8},
{at::ScalarType::Byte, Tensor::INT8},
{at::ScalarType::Int, Tensor::INT32},
{at::ScalarType::Long, Tensor::INT64},
{at::ScalarType::Float, Tensor::FP32},
{at::ScalarType::Half, Tensor::FP16},
{at::ScalarType::BFloat16, Tensor::BF16},
{at::ScalarType::Short, Tensor::INT16},
{at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3},
{at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2},
};
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
......@@ -51,15 +51,15 @@ at::Tensor to_torch(Tensor input) {
}
static const std::map<Tensor::ScalarType, at::ScalarType> mapType = {
{ Tensor::INT8, at::ScalarType::Byte },
{ Tensor::INT32, at::ScalarType::Int },
{ Tensor::INT64, at::ScalarType::Long },
{ Tensor::FP32, at::ScalarType::Float },
{ Tensor::FP16, at::ScalarType::Half },
{ Tensor::BF16, at::ScalarType::BFloat16 },
{ Tensor::INT16, at::ScalarType::Short },
{ Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn },
{ Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2 },
{Tensor::INT8, at::ScalarType::Byte},
{Tensor::INT32, at::ScalarType::Int},
{Tensor::INT64, at::ScalarType::Long},
{Tensor::FP32, at::ScalarType::Float},
{Tensor::FP16, at::ScalarType::Half},
{Tensor::BF16, at::ScalarType::BFloat16},
{Tensor::INT16, at::ScalarType::Short},
{Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn},
{Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2},
};
c10::TensorOptions opts(mapType.at(input.scalar_type()));
......@@ -82,4 +82,4 @@ TorchOpContext::TorchOpContext() {
TorchOpContext::~TorchOpContext() {
assert(stackCUDAStreams.top() == at::cuda::getCurrentCUDAStream().stream());
stackCUDAStreams.pop();
}
\ No newline at end of file
}
......@@ -8,15 +8,16 @@
class BufferTorchTensor : public Buffer {
public:
BufferTorchTensor(at::Tensor tensor) : tensor(std::move(tensor)) {
this->size = this->tensor.numel() * this->tensor.itemsize();
this->ptr = this->tensor.data_ptr();
this->size = this->tensor.numel() * this->tensor.itemsize();
this->ptr = this->tensor.data_ptr();
this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
this->device.idx = this->tensor.get_device();
this->device.idx = this->tensor.get_device();
}
virtual bool isAsyncBuffer() override {
// TODO: figure out how torch manages memory
return this->device.type == Device::CUDA;
}
private:
at::Tensor tensor;
};
......@@ -25,7 +26,7 @@ class TorchOpContext {
public:
TorchOpContext();
TorchOpContext(const TorchOpContext &) = delete;
TorchOpContext(TorchOpContext &&) = delete;
TorchOpContext(TorchOpContext &&) = delete;
~TorchOpContext();
};
......@@ -48,4 +49,4 @@ public:
private:
std::map<std::string, at::Tensor> storage;
};
\ No newline at end of file
};
......@@ -3,91 +3,84 @@
#include "dispatch_utils.h"
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
});
void silu_and_mul(
Tensor& out, // [..., d]
Tensor& input) // [..., 2 * d]
void silu_and_mul(Tensor &out, // [..., d]
Tensor &input) // [..., 2 * d]
{
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
// dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
// vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
// out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
// });
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
// dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
// vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
// out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
// });
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
vllm::silu_and_mul_kernel<scalar_t>
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
}
void invoke_dequant_silu_and_mul_quant(
Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate, const float scale_up, const float scale_out) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate,
scale_up, scale_out);
void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate,
const float scale_up,
const float scale_out) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate, scale_up, scale_out);
}
void invoke_dequant_silu_and_mul_quant(
Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate, const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [..., d]
void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate,
const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [..., d]
) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float*, true><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(),
d, scale_gate, scale_up, scale_out.data_ptr<float>(), tmp.data_ptr<float>());
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float *, true><<<grid, block, 0, stream>>>(out.data_ptr<int8_t>(),
input.data_ptr<int32_t>(),
d,
scale_gate,
scale_up,
scale_out.data_ptr<float>(),
tmp.data_ptr<float>());
}
void silu(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
void silu(Tensor &out, // [..., d]
Tensor &input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::silu);
LAUNCH_ACTIVATION_KERNEL(vllm::silu);
}
void gelu_new(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
void gelu_new(Tensor &out, // [..., d]
Tensor &input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
void gelu_fast(Tensor &out, // [..., d]
Tensor &input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
\ No newline at end of file
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
......@@ -3,9 +3,8 @@
#include "common.h"
#include "Tensor.h"
void silu(
Tensor& out, // [..., d]
Tensor& input);
void silu(Tensor &out, // [..., d]
Tensor &input);
void silu_and_mul(Tensor &out, // [..., d]
Tensor &input); // [..., 2 * d]
......@@ -25,5 +24,5 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
const float scale_gate,
const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [num_tokens, d]
);
\ No newline at end of file
Tensor &tmp // [num_tokens, d]
);
......@@ -3,116 +3,104 @@
namespace vllm {
template <typename T> __device__ __forceinline__ T silu(const T &x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
template<typename T>
__device__ __forceinline__ T silu(const T &x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2 * d]
const int d) {
__global__ void silu_and_mul_kernel(scalar_t *__restrict__ out, // [..., d]
const scalar_t *__restrict__ input, // [..., 2 * d]
const int d) {
const int token_idx = blockIdx.x;
const int64_t token_idx_d = token_idx * int64_t(d);
const int64_t token_idx_2d = token_idx_d * 2;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx_2d + idx]);
const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
out[token_idx_d + idx] = silu(x) * y;
}
const int token_idx = blockIdx.x;
const int64_t token_idx_d = token_idx * int64_t(d);
const int64_t token_idx_2d = token_idx_d * 2;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx_2d + idx]);
const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
out[token_idx_d + idx] = silu(x) * y;
}
}
// dequant int32 input, apply silu and mul, then per token quant to int8
template <typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(
int8_t *__restrict__ out, // [..., d]
const int32_t *__restrict__ input, // [..., 2 * d]
const int d, const float scale_gate, const float scale_up,
scale_type scale_out, // [num_tokens]
float *__restrict__ tmp = nullptr // [num_tokens, d]
template<typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out, // [..., d]
const int32_t *__restrict__ input, // [..., 2 * d]
const int d,
const float scale_gate,
const float scale_up,
scale_type scale_out, // [num_tokens]
float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
const int token_idx = blockIdx.x;
if constexpr (use_per_token_quant) {
float amax_val = 0.0f;
const float zero = 0.0f;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x =
(float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y =
(float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
float t = silu(x) * y;
tmp[token_idx * d + idx] = t;
t = t > zero ? t : -t;
if (t > amax_val)
amax_val = t;
}
__shared__ float s_amax;
const float block_amax_val = blockReduceMax(amax_val);
if (threadIdx.x == 0) {
s_amax = block_amax_val;
scale_out[token_idx] = block_amax_val / 127.0f;
const int token_idx = blockIdx.x;
if constexpr (use_per_token_quant) {
float amax_val = 0.0f;
const float zero = 0.0f;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
float t = silu(x) * y;
tmp[token_idx * d + idx] = t;
t = t > zero ? t : -t;
if (t > amax_val)
amax_val = t;
}
__shared__ float s_amax;
const float block_amax_val = blockReduceMax(amax_val);
if (threadIdx.x == 0) {
s_amax = block_amax_val;
scale_out[token_idx] = block_amax_val / 127.0f;
}
__syncthreads();
float tmp_scale = 127.0f / s_amax;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
out[token_idx * d + idx] = float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
}
} else {
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
}
}
__syncthreads();
float tmp_scale = 127.0f / s_amax;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
out[token_idx * d + idx] =
float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
}
} else {
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x =
(float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y =
(float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
}
}
}
} // namespace vllm
namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t &)>
__global__ void activation_kernel(scalar_t *__restrict__ out, // [..., d]
const scalar_t *__restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
} // namespace vllm
namespace vllm {
template <typename T> __device__ __forceinline__ T gelu_new_kernel(const T &x) {
const float x3 = (float)(x * x * x);
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T &x) {
const float x3 = (float)(x * x * x);
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
}
template <typename T>
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
const float f = (float)x;
const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
const float f = (float)x;
const T t = (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
}
} // namespace vllm
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
......@@ -13,16 +14,15 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
#include <cuda_fp16.h>
#include <cstdint>
__forceinline__ __device__
void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
......@@ -75,25 +75,26 @@ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
__forceinline__ __device__
void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
// dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
// *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->x));
// *reinterpret_cast<__nv_bfloat162 *>(&result->y) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y));
// *reinterpret_cast<__nv_bfloat162 *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->w));
// *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *>(&result->x)); *reinterpret_cast<__nv_bfloat162 *>(&result->y) =
// cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__nv_bfloat162
// *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *>(&result->w));
// return;
// uint4 result;
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
......@@ -127,4 +128,4 @@ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_67
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
}
\ No newline at end of file
}
......@@ -2,14 +2,13 @@
#include <cuda_bf16.h>
#include "semaphore.h"
#include "gemm_awq.h"
//#include "../../../nunchaku/csrc/quantization/dequantize.cuh"
// #include "../../../nunchaku/csrc/quantization/dequantize.cuh"
#include "dequantize.cuh"
#include <stdio.h>
#include "../dispatch_utils.h"
//#include "../../../nunchaku/csrc/utils.cuh"
// #include "../../../nunchaku/csrc/utils.cuh"
#include "../utils.cuh"
#include <cuda_pipeline_primitives.h>
#define kInterleave 4
......@@ -29,1141 +28,1342 @@
#define L2_CACHEHINT(size)
#endif
#define KERNEL_LAUNCH_CODE \
int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(f16_t); \
if (kSmemByteSize >= 99 * 1024) \
{ \
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \
return _out_feats; \
} \
int j_factors1 = num_out_channels / CTA_N / 1; \
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n)
{
if (N >= 8 && n >= 6)
return 3;
else if (N >= 4 && n >= 3)
return 2;
else if (N >= 2 && n >= 2)
return 1;
else
return 0;
#define KERNEL_LAUNCH_CODE \
int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = \
(G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
constexpr int kSmemByteSize = \
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * \
sizeof(f16_t); \
if (kSmemByteSize >= 99 * 1024) { \
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \
return _out_feats; \
} \
int j_factors1 = num_out_channels / CTA_N / 1; \
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
template<int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
if (N >= 8 && n >= 6)
return 3;
else if (N >= 4 && n >= 3)
return 2;
else if (N >= 2 && n >= 2)
return 1;
else
return 0;
}
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}
template <int SLICES, int NUM_WARPS_MN>
__device__ void sync_slice(int slice_id)
{
if constexpr (SLICES == 1)
{
__syncthreads();
}
else
{
constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
}
template<int SLICES, int NUM_WARPS_MN>
__device__ void sync_slice(int slice_id) {
if constexpr (SLICES == 1) {
__syncthreads();
} else {
constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
}
}
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr)
{
uint32_t smem_int_ptr;
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
return smem_int_ptr;
return smem_int_ptr;
}
template <typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr)
{
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
// Loads four 8x8 b16 matrix fragments from shared memory (shared-space address
// `addr`) into the register tile beginning at shared_warp + 8 * ax0_0, using
// the warp-collective ldmatrix PTX instruction.
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
    // Each of the four 32-bit outputs packs two b16 elements.
    unsigned *frag = reinterpret_cast<unsigned *>(shared_warp + ax0_0 * 8);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
                 : "r"(addr));
}
// NOTE(review): pre-reformat copy of ldmatrix_m8n8_x4_trans_b16; the
// clang-formatted version follows. Diff residue -- missing the closing brace.
template <typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr)
{
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
// Same as ldmatrix_m8n8_x4_b16 but with the .trans qualifier (transposed load).
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
// Transposed variant of ldmatrix_m8n8_x4_b16: loads four 8x8 b16 tiles from
// shared memory at `addr` (transposing each tile in flight) into the register
// fragment starting at shared_warp + 8 * ax0_0.
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
    unsigned *frag = reinterpret_cast<unsigned *>(shared_warp + ax0_0 * 8);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
                 : "r"(addr));
}
// NOTE(review): pre-reformat copy of cp_async_cg_A; the clang-formatted
// version follows. Diff residue -- missing the closing brace; keep one copy.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
{
// 16-byte predicated global->shared async copy; skipped when mask is false.
const int cp_size = 16;
asm volatile("{"
" .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
"}" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
// Issues one predicated 16-byte asynchronous copy (global -> shared memory)
// with the .cg cache hint. When `mask` is false the PTX predicate is cleared
// and the copy becomes a no-op, so out-of-bounds threads can call it safely.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
    constexpr int cp_size = 16;  // bytes moved per cp.async.cg request
    asm volatile("{ .reg .pred p;"
                 " setp.ne.b32 p, %0, 0;"
                 " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
                 "}" ::"r"((int)mask),
                 "r"(smem_int_ptr),
                 "l"(src),
                 "n"(cp_size));
}
// Forward declaration of the generic Tensor Core m16n8k16 wrapper; only the
// half and __nv_bfloat16 specializations below are defined.
// NOTE(review): the first template header line is a duplicated pre-reformat
// copy (diff residue) -- keep exactly one of the two.
template <typename f16_t>
template<typename f16_t>
__device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16_t *B_shared_warp);
// NOTE(review): pre-reformat copy of the half specialization; the
// clang-formatted version follows. Diff residue -- missing the closing brace.
template <>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
// One Tensor Core tile op: C(16x8, f32) += A(16x16, f16) * B(16x8, f16).
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
// Half-precision Tensor Core tile op: C(16x8, f32) += A(16x16, f16) * B(16x8,
// f16), accumulating in place over the warp-distributed fragments.
template<>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp) {
    float *acc = C_warp;
    unsigned *a_frag = reinterpret_cast<unsigned *>(A_shared_warp);  // 4 regs, 2 f16 each
    unsigned *b_frag = reinterpret_cast<unsigned *>(B_shared_warp);  // 2 regs, 2 f16 each
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
                 : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
                 : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
                   "r"(b_frag[0]), "r"(b_frag[1]),
                   "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
}
// NOTE(review): pre-reformat copy of the __nv_bfloat16 specialization; the
// clang-formatted version follows. Diff residue -- missing the closing brace.
template <>
__device__ __inline__ void mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp)
{
// One Tensor Core tile op: C(16x8, f32) += A(16x16, bf16) * B(16x8, bf16).
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
// bfloat16 Tensor Core tile op: C(16x8, f32) += A(16x16, bf16) * B(16x8, bf16),
// accumulating in place over the warp-distributed fragments.
template<>
__device__ __inline__ void
mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
    float *acc = C_warp;
    unsigned *a_frag = reinterpret_cast<unsigned *>(A_shared_warp);  // 4 regs, 2 bf16 each
    unsigned *b_frag = reinterpret_cast<unsigned *>(B_shared_warp);  // 2 regs, 2 bf16 each
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
                 : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
                 : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
                   "r"(b_frag[0]), "r"(b_frag[1]),
                   "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
}
// NOTE(review): the span below interleaves the pre-reformat and clang-formatted
// copies of global_to_share_one_stage_A (merge/diff residue: two signatures,
// two loop bodies, unbalanced braces). It cannot compile as-is -- confirm
// against the repository and keep exactly one copy.
//
// Purpose (from the visible code): copies one CTA_M x CTA_K tile of activation
// matrix A from global memory (`src`) to shared memory (`dst`), one uint4
// (PACK_SIZE elements) per thread per iteration. Destination columns are
// XOR-swizzled against the row to avoid bank conflicts; rows are masked
// against `global_nrows`. With STAGES > 1 the copy is issued asynchronously
// via cp_async_cg_A, otherwise it is a plain masked 16-byte store.
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(f16_t *src, f16_t *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(f16_t *src,
f16_t *dst,
int global_nrows,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int cta_offset_k,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
}
else
{
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr =
(uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K +
cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols +
// threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row)
// * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) *
// PACK_SIZE);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
} else {
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
// NOTE(review): the span below interleaves the pre-reformat and clang-formatted
// copies of global_to_share_one_stage_B (merge/diff residue: two signatures,
// two loop bodies, unbalanced braces). It cannot compile as-is -- confirm
// against the repository and keep exactly one copy.
//
// Purpose (from the visible code): copies one interleaved CTA_N/kInterleave x
// CTA_K tile of the packed weight matrix B from global to shared memory, one
// uint4 per thread per iteration, XOR-swizzling the destination column. No row
// bounds check is applied here (unlike the A-tile loader). With STAGES > 1 the
// copy goes through cp_async_cg_A, otherwise it is a plain masked store.
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(f16_t *src, f16_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(f16_t *src,
f16_t *dst,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int cta_offset_k,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
}
else
{
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols +
ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
} else {
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
// NOTE(review): the span below interleaves the pre-reformat and clang-formatted
// copies of global_to_share_one_stage_scales (merge/diff residue with
// unbalanced braces). It cannot compile as-is -- confirm against the
// repository and keep exactly one copy.
//
// Purpose (from the visible code): loads one group's worth of per-channel
// dequantization scales (`src` -> `dst`) and zeros (`src_z` -> `dst_z`) for
// the current K position (group index g_idx = (cta_offset_k + global_iter_k *
// CTA_K) / G). Note the stage predicate here is a runtime `if (STAGES > 1)`,
// not `if constexpr` as in the A/B loaders -- presumably equivalent since
// STAGES is a template parameter, but worth confirming intentionality.
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales(f16_t *src, f16_t *dst, f16_t *src_z, f16_t *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = LD_AMOUNT / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;
void *dst_ptr = (void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
}
else
{
if (local_mask)
{
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales(f16_t *src,
f16_t *dst,
f16_t *src_z,
f16_t *dst_z,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int cta_offset_k,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = LD_AMOUNT / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;
void *dst_ptr =
(void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr =
(uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols +
(threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z =
(void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z =
(uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols +
(threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
} else {
if (local_mask) {
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
}
}
}
}
// NOTE(review): the span below interleaves the pre-reformat and clang-formatted
// copies of share_to_reg_one_stage_A (merge/diff residue; loop body duplicated,
// braces unbalanced). It cannot compile as-is -- confirm against the
// repository and keep exactly one copy.
//
// Purpose (from the visible code): moves `shared_iters` 16x16 fragments of the
// A tile from shared memory into registers via ldmatrix_m8n8_x4_b16, applying
// the same XOR swizzle used when the tile was stored by
// global_to_share_one_stage_A.
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void share_to_reg_one_stage_A(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) {
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
// NOTE(review): the span below interleaves the pre-reformat and clang-formatted
// copies of share_to_reg_one_stage_B (merge/diff residue; duplicated loops,
// unbalanced braces). It cannot compile as-is -- confirm against the
// repository and keep exactly one copy.
//
// Purpose (from the visible code): optionally (when the `ldmatrix` template
// flag is set) loads packed 4-bit weight fragments from shared memory into
// registers (`dst`) via ldmatrix_m8n8_x4_b16, then dequantizes them to f16/bf16
// pairs -- dequantize_s4_to_fp16x2 followed by __hfma2 with the per-group
// scale/zero broadcast via f162f162 -- writing the result to `dst_fp16`.
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B(f16_t *src, f16_t *src_scales, f16_t *src_zeros, f16_t *dst, f16_t *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix)
{
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B(f16_t *src,
f16_t *src_scales,
f16_t *src_zeros,
f16_t *dst,
f16_t *dst_fp16,
int warp_offset_m,
int warp_offset_n,
int warp_offset_k,
int k_0_1) {
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix) {
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
void *addr_ptr =
(void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol +
k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
}
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
f16_t scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f16_t zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
f16_t scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) +
threadIdx.x / 4];
f16_t zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) +
threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8),
reinterpret_cast<uint4 *>(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
loaded[i] = __hfma2(loaded[i], scale2, zero2);
for (int i = 0; i < 4; i++) {
loaded[i] = __hfma2(loaded[i], scale2, zero2);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G, int SPLITK>
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K)
{
template<typename f16_t,
int CTA_M,
int CTA_N,
int CTA_K,
int WARP_M,
int WARP_N,
int WARP_K,
int STAGES,
int G,
int SPLITK>
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
f16_t *__restrict__ B,
f16_t *__restrict__ scales,
f16_t *__restrict__ zeros,
f16_t *__restrict__ C,
int *__restrict__ semaphores,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
trap_unsupported_arch();
return;
#endif
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
float *C_shared = reinterpret_cast<float *>(mem_shared);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int cta_offset_k = blockIdx_z * (K / SPLITK);
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
float *C_shared = reinterpret_cast<float *>(mem_shared);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int cta_offset_k = blockIdx_z * (K / SPLITK);
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A,
A_shared +
k_0_0_ld * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
0,
true);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B,
B_shared +
k_0_0_ld * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
0,
true);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales,
scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
zeros,
zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
0,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared,
scales_shared,
zeros_shared,
B_shared_warp_tmp_[0],
B_shared_warp_[0],
warp_offset_m,
warp_offset_n,
warp_offset_k,
0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
#pragma unroll
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage =
scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
zeros_shared_this_compute_stage =
zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared_this_compute_stage,
A_shared_warp_[(iter_k + 1) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0) {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
}
} else {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
warp_offset_k,
(iter_k + 1) % SHARED_K_ITERS);
}
}
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1) {
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2) {
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
__syncthreads();
}
global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales,
scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
zeros,
zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
cta_offset_k,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1) {
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1) {
#pragma unroll
for (int z = 0; z < SLICES; ++z)
{
if (slice_id == z)
{
for (int z = 0; z < SLICES; ++z) {
if (slice_id == z) {
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
if (z > 0)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
if (z > 0) {
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] +=
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n +
ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
(local_id % 2) + (threadIdx.x % 4) * 2] =
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
}
}
__syncthreads();
}
}
__syncthreads();
}
if (slice_id == 0)
{
if (slice_id == 0) {
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
};
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] =
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
(local_id % 2) + (threadIdx.x % 4) * 2];
};
}
}
}
}
}
}
if (slice_id == 0)
{
Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
if (slice_id == 0) {
Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
if constexpr (SPLITK > 1)
{
semaphore.fetch();
}
if constexpr (SPLITK > 1) {
semaphore.fetch();
}
if (blockIdx_z != 0)
{
semaphore.wait(blockIdx_z);
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
f162_t *existing_psum_ptr = reinterpret_cast<f162_t *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2);
*existing_psum_ptr = __hadd2(
*existing_psum_ptr,
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(
C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
if (blockIdx_z != 0) {
semaphore.wait(blockIdx_z);
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M) {
f162_t *existing_psum_ptr = reinterpret_cast<f162_t *>(
C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + (local_id / 4) * 8 +
(local_id % 2) + (threadIdx.x % 4) * 2);
*existing_psum_ptr =
__hadd2(*existing_psum_ptr,
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(
C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
}
};
}
}
};
}
}
}
else
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
*reinterpret_cast<f162_t *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
} else {
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M) {
*reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n +
ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2) +
(threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
}
};
}
}
}
if constexpr (SPLITK > 1)
{
if constexpr (SPLITK > 1) {
int lock = 0;
if (SPLITK == blockIdx_z + 1)
{
int lock = 0;
if (SPLITK == blockIdx_z + 1) {
lock = 0;
}
else
{
lock = blockIdx_z + 1;
}
semaphore.release(lock);
lock = 0;
} else {
lock = blockIdx_z + 1;
}
semaphore.release(lock);
}
}
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A_T2(f16_t *src, f16_t *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A_T2(f16_t *src,
f16_t *dst,
int global_nrows,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
}
else
{
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr =
(uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE +
global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n *
// global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols +
// (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K
// + (threadIdx.x % threads_per_row) * PACK_SIZE);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
} else {
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B_T2(f16_t *src, f16_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B_T2(f16_t *src,
f16_t *dst,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
}
else
{
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols +
ld_row * global_ncols + ld_col * PACK_SIZE);
if constexpr (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
} else {
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales_T2(f16_t *src, f16_t *dst, f16_t *src_z, f16_t *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = global_iter_k * CTA_K / G;
void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
}
else
{
if (local_mask)
{
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales_T2(f16_t *src,
f16_t *dst,
f16_t *src_z,
f16_t *dst_z,
int global_ncols,
int cta_offset_m,
int cta_offset_n,
int global_iter_k,
int shared_iter_k,
bool mask) {
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = global_iter_k * CTA_K / G;
void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z =
(uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1) {
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
} else {
if (local_mask) {
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
}
}
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void share_to_reg_one_stage_A_T2(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A_T2(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1) {
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src, f16_t *src_scales, f16_t *src_zeros, f16_t *dst, f16_t *dst_fp16, int warp_offset_m, int warp_offset_n, int k_0_1)
{
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix)
{
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src,
f16_t *src_scales,
f16_t *src_zeros,
f16_t *dst,
f16_t *dst_fp16,
int warp_offset_m,
int warp_offset_n,
int k_0_1) {
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix) {
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
void *addr_ptr =
(void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol +
k_0_1 * 16 + r * kSmemCol + c_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
}
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
f16_t scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f16_t zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
f16_t scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f16_t zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8),
reinterpret_cast<uint4 *>(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
loaded[i] = __hfma2(loaded[i], scale2, zero2);
for (int i = 0; i < 4; i++) {
loaded[i] = __hfma2(loaded[i], scale2, zero2);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
}
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int M, int N, int K)
{
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
f16_t *__restrict__ B,
f16_t *__restrict__ scales,
f16_t *__restrict__ zeros,
f16_t *__restrict__ C,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
trap_unsupported_arch();
return;
#endif
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
constexpr int kSmemSizeZeros = CTA_N * STAGES / 2;
constexpr int scales_load_interval = G / CTA_K;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
constexpr int kSmemSizeZeros = CTA_N * STAGES / 2;
constexpr int scales_load_interval = G / CTA_K;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast<f16_t *>(mem_shared);
f16_t *B_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
__syncthreads();
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (ld_stage / scales_load_interval) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
scales,
scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
zeros,
zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
0,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared,
scales_shared,
zeros_shared,
B_shared_warp_tmp_[0],
B_shared_warp_[0],
warp_offset_m,
warp_offset_n,
0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
A_shared_this_compute_stage,
A_shared_warp_[(iter_k + 1) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0) {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
}
} else {
if (compute_stage % 2 == 1) {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
} else {
share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage,
scales_shared_this_compute_stage,
zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0],
B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m,
warp_offset_n,
(iter_k + 1) % SHARED_K_ITERS);
}
}
__syncthreads();
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
A_shared_warp + i_0_3 * 8,
B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1) {
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2) {
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
__syncthreads();
}
global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
A,
A_shared + ld_stage * kSmemSizeAPerStage,
M,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
B,
B_shared + ld_stage * kSmemSizeBPerStage,
K,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k + 1,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales,
scales_shared + (ld_stage / scales_load_interval) * CTA_N,
zeros,
zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
N,
cta_offset_m,
cta_offset_n,
k_0_0_ld,
iter_k,
k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1) {
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
*reinterpret_cast<f162_t *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
int write_row =
cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M) {
*reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
};
}
}
}
Tensor awq_gemm_forward_cuda(
Tensor _in_feats,
Tensor _kernel,
Tensor _scales,
Tensor _zeros)
{
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = _kernel.size(0) * kInterleave;
int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
int num_in_channels = _in_feats.size(-1);
auto options =
Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
auto options_int =
Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
int num_out_channels = _out_feats.size(-1);
if (_in_feats.scalar_type() == Tensor::FP16)
{
using f16_t = half;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 64)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 128)
{
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 192)
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024)
{
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
}
else if (_in_feats.scalar_type() == Tensor::BF16)
{
using f16_t = __nv_bfloat16;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 64)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 128)
{
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 192)
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024)
{
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, Tensor _zeros) {
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = _kernel.size(0) * kInterleave;
int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
int num_in_channels = _in_feats.size(-1);
auto options = Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
auto options_int = Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
int num_out_channels = _out_feats.size(-1);
if (_in_feats.scalar_type() == Tensor::FP16) {
using f16_t = half;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 64) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 128) {
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 192) {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize =
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES *
sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024) {
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else if (_in_feats.scalar_type() == Tensor::BF16) {
using f16_t = __nv_bfloat16;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<f16_t *>(_scales.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (num_out_feats <= 32) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 64) {
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 128) {
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else if (num_out_feats <= 192) {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
} else {
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize =
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES *
sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024) {
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else {
throw std::runtime_error("Unsupported input type");
}
}
else
{
throw std::runtime_error("Unsupported input type");
}
return _out_feats;
}
\ No newline at end of file
return _out_feats;
}
......@@ -3,9 +3,4 @@
#include "common.h"
#include "Tensor.h"
Tensor awq_gemm_forward_cuda(
Tensor _in_feats,
Tensor _kernel,
Tensor _scales,
Tensor _zeros);
Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, Tensor _zeros);
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Modified from NVIDIA
* [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -39,106 +40,98 @@
#define MEM_ACCESS_SIZE 128
// Reduce sum within the warp using the tree reduction algorithm.
template <typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t* psum, float (*out_smem)[Num * 4])
{
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i)
{
fpsum[i] = static_cast<float>(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i)
{
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
{
#pragma unroll
for (int i = 0; i < Num; ++i)
{
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
template<typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t *psum, float (*out_smem)[Num * 4]) {
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i) {
fpsum[i] = static_cast<float>(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i) {
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6) {
#pragma unroll
for (int i = 0; i < Num; ++i) {
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
};
__device__ __forceinline__ int make_divisible(int c, int divisor){
return (c + divisor - 1) / divisor;
__device__ __forceinline__ int make_divisible(int c, int divisor) {
return (c + divisor - 1) / divisor;
}
template<typename half_t>
__device__ __forceinline__
packed_as<half_t, 2>::type half2half2(half_t x);
template<typename half_t>
__device__ __forceinline__ packed_as<half_t, 2>::type half2half2(half_t x);
template<>
__device__ __forceinline__
packed_as<half, 2>::type half2half2<half>(half x) {
template<>
__device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
return __half2half2(x);
}
template<>
__device__ __forceinline__
packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
template<typename T>
__device__ __forceinline__
float2 half22float2(T val);
__device__ __forceinline__ float2 half22float2(T val);
template<>
__device__ __forceinline__
float2 half22float2<half2>(half2 val) {
__device__ __forceinline__ float2 half22float2<half2>(half2 val) {
return __half22float2(val);
}
template<>
__device__ __forceinline__
float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
template <typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const half_t* inputs, const uint32_t* weight, const half_t* scales, const half_t* zeros, half_t* outputs,
const int IC, const int OC)
{
template<typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(const half_t *inputs,
const uint32_t *weight,
const half_t *scales,
const half_t *zeros,
half_t *outputs,
const int IC,
const int OC) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
if constexpr(std::is_same_v<half_t, __nv_bfloat16>) {
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
trap_unsupported_arch();
return;
}
#endif
using half2_t = typename packed_as<half_t, 2>::type;
using accum_t = float;
using half2_t = typename packed_as<half_t, 2>::type;
using accum_t = float;
using accum2_t = typename packed_as<accum_t, 2>::type;
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
// assert(MEM_ACCESS_SIZE == 128);
// static constexpr int kShuffleSize = 32;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
static constexpr int kShuffleStrided = 4;
constexpr int Num = NPerBlock * Batch;
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
alignas(16) half_t local_inputs[kElemsPerThread];
alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
alignas(16) half_t half_weight_buffer[kElemsPerThread];
alignas(16) half_t half_weight_buffer[kElemsPerThread];
alignas(16) half_t dequantized_weight[kElemsPerThread * NPerBlock];
alignas(16) half_t local_scale[NPerBlock];
alignas(16) half_t local_scaled_zeros[NPerBlock];
......@@ -146,7 +139,7 @@ __global__ void gemv_kernel(
accum_t psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = static_cast<accum_t>(0.f);
// extern __shared__ uint8_t shmem[];
// float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
......@@ -154,80 +147,67 @@ __global__ void gemv_kernel(
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride
+ (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride +
(threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half_t* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* inputs_ptr = inputs + act_k_offset;
const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half_t *scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t *zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t *inputs_ptr = inputs + act_k_offset;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
// Main loop iteration, each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) {
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx) {
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
*((float4*)(local_qweights)) =
*((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
*((float4 *)(local_qweights)) = *((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) {
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i),
reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
half2_t w =
*reinterpret_cast<half2_t*>(
half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile
);
w = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)
* NPerBlock + idx]
= w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)
* NPerBlock + idx]
= w.y;
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i) {
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j) {
half2_t w = *reinterpret_cast<half2_t *>(half_weight_buffer +
(i + j * kShuffleContinous) * kShuffleBasicTile);
w = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
}
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const half_t* local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) {
const half_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx) {
// load activation, 8 halves (128 bits) / step.
*((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8));
*((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
accum2_t prod = cuda_cast<accum2_t>(__hmul2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2), half2half2(local_inputs[y])));
*reinterpret_cast<accum2_t*>(psum + batch_idx * NPerBlock + x * 2)
= prod + *reinterpret_cast<accum2_t*>(psum + batch_idx * NPerBlock + x * 2);
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x) {
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y) {
accum2_t prod = cuda_cast<accum2_t>(
__hmul2(*reinterpret_cast<half2_t *>(dequantized_weight + y * NPerBlock + x * 2),
half2half2(local_inputs[y])));
*reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2) =
prod + *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
// *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2)
// = __hfma2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2),
// half2half2(local_inputs[y]),
......@@ -243,13 +223,11 @@ __global__ void gemv_kernel(
warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);
// Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) {
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j) {
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half_t>(acc);
......@@ -271,32 +249,24 @@ Returns:
out_feats: tensor of shape [B, OC];
*/
Tensor gemv_awq(
Tensor _in_feats,
Tensor _kernel,
Tensor _scaling_factors,
Tensor _zeros,
int m,
int n,
int k,
int group_size)
{
Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size) {
return dispatchFloat16(_scaling_factors.scalar_type(), [&]<typename half_t>() {
assert(isTypeMatch<half_t>(_in_feats.dtype()));
auto output_shape = _in_feats.shape.dataExtent;
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = n;
auto in_feats = reinterpret_cast<half_t*>(_in_feats.data_ptr<half_t>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half_t*>(_zeros.data_ptr<half_t>());
auto scaling_factors = reinterpret_cast<half_t*>(_scaling_factors.data_ptr<half_t>());
auto in_feats = reinterpret_cast<half_t *>(_in_feats.data_ptr<half_t>());
auto kernel = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half_t *>(_zeros.data_ptr<half_t>());
auto scaling_factors = reinterpret_cast<half_t *>(_scaling_factors.data_ptr<half_t>());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
half_t * out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
half_t *out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
static constexpr int BLOCK_SIZE = 256;
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
......@@ -312,9 +282,9 @@ Tensor gemv_awq(
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
<<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
checkCUDA(cudaGetLastError());
}
});
......
......@@ -3,12 +3,5 @@
#include "common.h"
#include "Tensor.h"
Tensor gemv_awq(
Tensor _in_feats,
Tensor _kernel,
Tensor _scaling_factors,
Tensor _zeros,
int m,
int n,
int k,
int group_size);
\ No newline at end of file
Tensor
gemv_awq(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size);
......@@ -41,65 +41,54 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore
{
class Semaphore {
public:
int *lock;
bool wait_thread;
int state;
int *lock;
bool wait_thread;
int state;
public:
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_),
wait_thread(thread_id < 0 || thread_id == 0),
state(-1)
{
}
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id)
: lock(lock_), wait_thread(thread_id < 0 || thread_id == 0), state(-1) {}
/// Permit fetching the synchronization mechanism early
__device__ void fetch()
{
if (wait_thread)
{
/// Permit fetching the synchronization mechanism early
__device__ void fetch() {
if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
}
/// Gets the internal state
__device__ int get_state() const
{
return state;
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0)
{
while (__syncthreads_and(state != status))
{
fetch();
/// Gets the internal state
__device__ int get_state() const {
return state;
}
__syncthreads();
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0) {
while (__syncthreads_and(state != status)) {
fetch();
}
__syncthreads();
}
/// Updates the lock with the given result
__device__ void release(int status = 0)
{
__syncthreads();
/// Updates the lock with the given result
__device__ void release(int status = 0) {
__syncthreads();
if (wait_thread)
{
if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -16,4 +16,4 @@ inline void dispatchF16(Tensor::ScalarType type, F &&func) {
} else {
assert(false);
}
}
\ No newline at end of file
}
......@@ -53,17 +53,16 @@ inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
}
#pragma nv_diagnostic push
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template
// "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
template<typename T>
inline bool isTypeMatch(Tensor::ScalarType scalarType) {
return dispatch(scalarType, []<typename scalar_t>() {
return std::is_same_v<scalar_t, T>;
});
return dispatch(scalarType, []<typename scalar_t>() { return std::is_same_v<scalar_t, T>; });
}
#pragma nv_diagnostic pop
template<typename F, int ...N>
template<typename F, int... N>
inline auto dispatchVal(int val, std::integer_sequence<int, N...>, F &&func) {
auto call = [&]<int i>() {
if (val == i) {
......@@ -82,5 +81,4 @@ inline auto dispatchBool(bool val, F &&func) {
}
}
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) dispatchFloat(TYPE, [&]<typename scalar_t>() { __VA_ARGS__(); });
\ No newline at end of file
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) dispatchFloat(TYPE, [&]<typename scalar_t>() { __VA_ARGS__(); });
......@@ -219,28 +219,30 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
// weight = weight.copy(weight.device());
dispatchF16(weight.dtype(), [&]<typename half_t>() {
using ElementOutput = half_t;
using ElementAccumulator = half_t;
using ElementOutput = half_t;
using ElementAccumulator = half_t;
using ElementComputeEpilogue = half_t;
using ElementInputA = half_t;
using ElementInputB = half_t;
using ElementInputA = half_t;
using ElementInputB = half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
using FilterShape = cutlass::MatrixShape<3, 3>;
using FilterShape = cutlass::MatrixShape<3, 3>;
using ThreadblockShape = cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm80,
......@@ -249,15 +251,14 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
FilterShape,
WarpShape,
InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput, ElementComputeEpilogue>,
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>,
cutlass::epilogue::thread::LinearCombination<ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput,
ElementComputeEpilogue>,
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<1,
ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
......@@ -267,15 +268,14 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
using DeviceKernel = typename cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C__),
cutlass::Tensor4DCoord(1, 1, 1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::conv::Mode::kCrossCorrelation,
1,
C_ // groups
cutlass::conv::Conv2dProblemSize problem_size(cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C__),
cutlass::Tensor4DCoord(1, 1, 1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::conv::Mode::kCrossCorrelation,
1,
C_ // groups
);
const int P = problem_size.P;
......@@ -292,11 +292,17 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
Tensor tmp_weight = Tensor::empty_like(weight);
cutlass::TensorRef<ElementInputA, LayoutInputA> a_ref(input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(2), input.stride(1), input.stride(0)));
cutlass::TensorRef<ElementInputB, LayoutInputB> b_ref(weight.data_ptr<ElementInputB>(), LayoutInputB(weight.stride(2), weight.stride(1), weight.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> c_ref(bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0, 0, 0));
cutlass::TensorRef<ElementOutput, LayoutOutput> d_ref(out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(2), out.stride(1), out.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> tmpw_ref(tmp_weight.data_ptr<ElementOutput>(), LayoutOutput(tmp_weight.stride(2), tmp_weight.stride(1), tmp_weight.stride(0)));
cutlass::TensorRef<ElementInputA, LayoutInputA> a_ref(
input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(2), input.stride(1), input.stride(0)));
cutlass::TensorRef<ElementInputB, LayoutInputB> b_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB(weight.stride(2), weight.stride(1), weight.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> c_ref(
bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0, 0, 0));
cutlass::TensorRef<ElementOutput, LayoutOutput> d_ref(
out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(2), out.stride(1), out.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> tmpw_ref(
tmp_weight.data_ptr<ElementOutput>(),
LayoutOutput(tmp_weight.stride(2), tmp_weight.stride(1), tmp_weight.stride(0)));
typename DeviceKernel::Arguments arguments{
problem_size,
......@@ -315,7 +321,6 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
BufferCUDA workspace(workspace_size);
auto stream = getCurrentCUDAStream();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
......@@ -333,4 +338,4 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
});
return out;
}
\ No newline at end of file
}
......@@ -5,4 +5,4 @@
// Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias);
\ No newline at end of file
Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias);
......@@ -9,18 +9,16 @@
using spdlog::fmt_lib::format;
Tensor gemm_batched_fp16(
Tensor a, // FP16 row-major [(... batch ...), M, K]
Tensor b, // FP16 col-major [(... batch ...), N, K]
Tensor out // FP32 row-major [(... batch ...), M, N]
)
{
const int M = a.shape[-2];
const int K = a.shape[-1];
const int N = a.shape[-2];
Tensor gemm_batched_fp16(Tensor a, // FP16 row-major [(... batch ...), M, K]
Tensor b, // FP16 col-major [(... batch ...), N, K]
Tensor out // FP32 row-major [(... batch ...), M, N]
) {
const int M = a.shape[-2];
const int K = a.shape[-1];
const int N = a.shape[-2];
const int batch = a.numel() / (M * K);
using ElementInput = cutlass::half_t;
using ElementInput = cutlass::half_t;
using ElementOutput = float;
using LayoutA = cutlass::layout::RowMajor;
......@@ -28,18 +26,23 @@ Tensor gemm_batched_fp16(
using LayoutO = cutlass::layout::RowMajor;
using Gemm = cutlass::gemm::device::GemmBatched<
ElementInput, LayoutA,
ElementInput, LayoutB,
ElementOutput, LayoutO,
ElementOutput,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
ElementInput,
LayoutA,
ElementInput,
LayoutB,
ElementOutput,
LayoutO,
ElementOutput,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput, ElementOutput>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
cutlass::epilogue::thread::LinearCombination<ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput,
ElementOutput>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
2>;
auto sizeA = cutlass::MatrixCoord(M, K);
......@@ -48,8 +51,8 @@ Tensor gemm_batched_fp16(
if (!out.valid()) {
auto outShape = TensorShape(a.shape.dataExtent);
outShape[-1] = N;
out = Tensor::empty(outShape, Tensor::FP32, a.device());
outShape[-1] = N;
out = Tensor::empty(outShape, Tensor::FP32, a.device());
}
assert(K == b.shape[-1]);
......@@ -62,28 +65,23 @@ Tensor gemm_batched_fp16(
cutlass::gemm::GemmCoord problemSize(M, N, K);
cutlass::TensorRef<ElementInput, LayoutA> refA(
a.data_ptr<ElementInput>(), LayoutA(a.stride(-2)));
cutlass::TensorRef<ElementInput, LayoutB> refB(
b.data_ptr<ElementInput>(), LayoutB(b.stride(-2)));
cutlass::TensorRef<ElementOutput, LayoutO> refO(
out.data_ptr<ElementOutput>(), LayoutO(out.stride(-2)));
typename Gemm::Arguments arguments{
problemSize,
refA,
(int)a.stride(-3),
refB,
(int)b.stride(-3),
refO,
(int)out.stride(-3),
refO,
(int)out.stride(-3),
{ ElementOutput(1), ElementOutput(0) },
batch
};
Gemm op;
cutlass::TensorRef<ElementInput, LayoutA> refA(a.data_ptr<ElementInput>(), LayoutA(a.stride(-2)));
cutlass::TensorRef<ElementInput, LayoutB> refB(b.data_ptr<ElementInput>(), LayoutB(b.stride(-2)));
cutlass::TensorRef<ElementOutput, LayoutO> refO(out.data_ptr<ElementOutput>(), LayoutO(out.stride(-2)));
typename Gemm::Arguments arguments{problemSize,
refA,
(int)a.stride(-3),
refB,
(int)b.stride(-3),
refO,
(int)out.stride(-3),
refO,
(int)out.stride(-3),
{ElementOutput(1), ElementOutput(0)},
batch};
Gemm op;
BufferCUDA workspace(Gemm::get_workspace_size(arguments));
cutlass::Status status = op.can_implement(arguments);
......@@ -102,4 +100,4 @@ Tensor gemm_batched_fp16(
}
return out;
}
\ No newline at end of file
}
......@@ -3,8 +3,7 @@
#include "common.h"
#include "Tensor.h"
Tensor gemm_batched_fp16(
Tensor a, // FP16 row-major [(... batch ...), M, K]
Tensor b, // FP16 col-major [(... batch ...), N, K]
Tensor out // FP32 row-major [(... batch ...), M, N]
);
\ No newline at end of file
Tensor gemm_batched_fp16(Tensor a, // FP16 row-major [(... batch ...), M, K]
Tensor b, // FP16 col-major [(... batch ...), N, K]
Tensor out // FP32 row-major [(... batch ...), M, N]
);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment