Commit b1b44398 authored by Samuel Tesfai's avatar Samuel Tesfai
Browse files

Fixing merges

parents 004e4e31 4b9c2e03
......@@ -8,11 +8,11 @@
using spdlog::fmt_lib::format;
using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) :
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_pad(ceilDiv(dim, 128) * 128),
qkv_proj(dim, dim_pad * 3, bias, dtype, device),
out_proj(dim_pad, dim, bias, dtype, device),
qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
pag_to_v(std::nullopt)
{
registerChildren
......@@ -21,7 +21,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::S
;
if (pag) {
pag_to_v.emplace(dim, dim_pad, bias, dtype, device);
pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
registerChildren(pag_to_v.value(), "pag_to_v");
}
}
......@@ -63,7 +63,11 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
qkv_proj.wscales,
{}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
vk, q,
qact.is_unsigned, qkv_proj.lora_scales, false);
qact.is_unsigned, qkv_proj.lora_scales, false,
qkv_proj.use_fp4,
*qkv_proj.wtscale.data_ptr<float>(),
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{}
);
debug("vk", vk);
debug("q", q);
......@@ -121,11 +125,11 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
return out;
}
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) :
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
num_heads(num_heads), head_dim(head_dim),
q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device),
q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device)
out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device)
{
registerChildren
(q_linear, "q_linear")
......@@ -173,11 +177,11 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
return out_proj.forward(attn_output);
}
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) :
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
in_features(in_features), hidden_features(hidden_features),
inverted_conv(in_features, hidden_features * 2, true, dtype, device),
inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
depth_conv(hidden_features * 2, true, dtype, device),
point_conv(hidden_features, in_features, false, dtype, device)
point_conv(hidden_features, in_features, false, use_fp4, dtype, device)
{
registerChildren
(inverted_conv, "inverted_conv")
......@@ -200,11 +204,11 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
return point_conv.forward_quant(qact);
}
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) :
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device),
ff(hidden_size, intermediate_size, dtype, device),
attn(hidden_size, false, pag, use_fp4, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
ff(hidden_size, intermediate_size, use_fp4, dtype, device),
norm1(hidden_size, 1e-6, false, dtype, device),
norm2(hidden_size, 1e-6, false, dtype, device)
{
......@@ -313,6 +317,7 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
config.num_cross_attention_heads,
std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
config.use_fp4,
dtype, device
));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
......
......@@ -7,7 +7,7 @@
class SanaLinearAttention : public Module {
public:
SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device);
SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor out = {});
Tensor forward_pag(Tensor x, bool cfg);
......@@ -25,7 +25,7 @@ private:
class MultiHeadCrossAttention : public Module {
public:
MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device);
MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);
......@@ -41,7 +41,7 @@ private:
class SanaGLUMBConv : public Module {
public:
SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device);
SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, int H, int W);
......@@ -57,7 +57,7 @@ private:
class SanaLinearTransformerBlock : public Module {
public:
SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device);
SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
......@@ -83,6 +83,7 @@ struct SanaConfig {
int num_cross_attention_heads;
double expand_ratio;
std::vector<int> pag_layers;
bool use_fp4;
};
class SanaModel : public Module {
......
......@@ -117,6 +117,8 @@ void SafeTensors::parseHeader() {
{ "I8", Tensor::INT8 },
{ "I32", Tensor::INT32 },
{ "I64", Tensor::INT64 },
{ "F8_E4M3", Tensor::FP8_E4M3 },
{ "F8_E5M2", Tensor::FP8_E5M2 },
};
auto check = [](bool cond, std::source_location location = std::source_location::current()) {
......
......@@ -218,7 +218,8 @@ public:
enum ScalarType {
INVALID_SCALAR_TYPE,
INT8, INT16, INT32, INT64,
FP16, FP32, BF16
FP16, FP32, BF16,
FP8_E4M3, FP8_E5M2,
};
struct TensorOptions {
......@@ -546,6 +547,8 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
{FP16, 2},
{FP32, 4},
{BF16, 2},
{FP8_E4M3, 1},
{FP8_E5M2, 1},
};
struct TensorsProvider {
......
......@@ -9,6 +9,7 @@
#include <memory>
#include <source_location>
#include <vector>
#include <list>
#include <stack>
#include <map>
#include <unordered_map>
......@@ -79,6 +80,15 @@ constexpr T ceilDiv(T a, T b) {
return (a + b - 1) / b;
}
template<typename T>
constexpr int log2Up(T value) {
if (value <= 0)
return 0;
if (value == 1)
return 0;
return log2Up((value + 1) / 2) + 1;
}
struct CUBLASWrapper {
cublasHandle_t handle = nullptr;
......
......@@ -28,8 +28,9 @@ Tensor from_torch(at::Tensor input) {
{ at::ScalarType::Float, Tensor::FP32 },
{ at::ScalarType::Half, Tensor::FP16 },
{ at::ScalarType::BFloat16, Tensor::BF16 },
{ at::ScalarType::Short, Tensor::INT16 },
{ at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3 },
{ at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2 },
};
result.scalarType = mapType.at(input.scalar_type());
......@@ -55,8 +56,9 @@ at::Tensor to_torch(Tensor input) {
{ Tensor::FP32, at::ScalarType::Float },
{ Tensor::FP16, at::ScalarType::Half },
{ Tensor::BF16, at::ScalarType::BFloat16 },
{ Tensor::INT16, at::ScalarType::Short },
{ Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn },
{ Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2 },
};
c10::TensorOptions opts(mapType.at(input.scalar_type()));
......
......@@ -140,8 +140,10 @@ __global__ void gemv_kernel(
for (int i = 0; i < Num; ++i)
psum[i] = static_cast<accum_t>(0.f);
extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
// extern __shared__ uint8_t shmem[];
// float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
__shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
......
......@@ -319,10 +319,10 @@ public:
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
if (pred) {
//if (pred) {
// out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]);
}
out[i] = load_pred(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], pred);
//}
}
}
......@@ -336,12 +336,12 @@ public:
// int offset = K / WARP_K * WARP_SIZE;
#pragma unroll
for (int i = 0; i < WARP_N_TILES; i++) {
if (pred) {
//if (pred) {
// out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
// out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
out[i] = load(&ptr[i * WARP_SIZE]);
out[i] = load_pred(&ptr[i * WARP_SIZE], pred);
// ptr += offset;
}
//}
}
}
......@@ -352,11 +352,11 @@ public:
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
if (pred && laneId < ASCALES_VALID_LANES) {
// if (pred && laneId < ASCALES_VALID_LANES) {
// out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId];
out[i] = load_pred(&ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId], pred && laneId < ASCALES_VALID_LANES);
}
// }
}
}
......@@ -373,13 +373,13 @@ public:
#pragma unroll
for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
if (pred && laneId < WSCALES_VALID_LANES) {
// if (pred && laneId < WSCALES_VALID_LANES) {
// out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
// out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]);
out[i] = load_pred(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId], pred && laneId < WSCALES_VALID_LANES);
// out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
}
// }
}
}
......@@ -400,7 +400,7 @@ public:
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
template<typename F>
template<bool FAST_I2F = false, typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
......@@ -429,12 +429,31 @@ public:
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
auto scale_fma_normal = [&]() ALWAYSINLINE {
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
};
// should be faster on sm_80
auto scale_fma_fast = [&]() ALWAYSINLINE {
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[0]), int2float_fast(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[2]), int2float_fast(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[4]), int2float_fast(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[6]), int2float_fast(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
if constexpr (FAST_I2F) {
scale_fma_fast();
} else {
scale_fma_normal();
}
#else
scale_fma_normal();
#endif
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
......@@ -575,9 +594,9 @@ public:
(plugins(i * INSN_M + row, pack), ...);
bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
// if (pred) {
store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack, pred);
// }
}
__syncwarp();
......@@ -602,9 +621,9 @@ public:
(plugins(i * INSN_M + 8 + row, pack), ...);
bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
}
// if (pred) {
store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack, pred);
// }
}
__syncwarp();
}
......@@ -680,33 +699,61 @@ public:
}
};
template<bool USE_BIAS = true, bool USE_SCALE = false>
struct EpilogueBias {
struct Arguments {
const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
const packed_wscale_t *scale;
};
__device__ __forceinline__
void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias) {
void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias, const packed_wscale_t *scale) {
const int laneId = threadIdx.x % WARP_SIZE;
// if (laneId == 0) {
// printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
// }
wscale_warp b;
load_wscale(bias, 0, N, b, true);
wscale_warp b, s;
if constexpr (USE_BIAS) {
load_wscale(bias, 0, N, b, true);
}
if constexpr (USE_SCALE) {
load_wscale(scale, 0, N, s, true);
}
for (int j = 0; j < WARP_N_TILES; j++) {
half2_t b1 = broadcast_wscale(b, j * 4, laneId);
half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId);
half2_t b1, b2;
half2_t s1, s2;
if constexpr (USE_BIAS) {
b1 = broadcast_wscale(b, j * 4, laneId);
b2 = broadcast_wscale(b, j * 4 + 2, laneId);
}
if constexpr (USE_SCALE) {
s1 = broadcast_wscale(s, j * 4, laneId);
s2 = broadcast_wscale(s, j * 4 + 2, laneId);
}
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
fsum.data[0] = __hadd2(fsum.data[0], b1);
fsum.data[1] = __hadd2(fsum.data[1], b1);
fsum.data[2] = __hadd2(fsum.data[2], b2);
fsum.data[3] = __hadd2(fsum.data[3], b2);
if constexpr (USE_SCALE && USE_BIAS) {
fsum.data[0] = __hfma2(fsum.data[0], s1, b1);
fsum.data[1] = __hfma2(fsum.data[1], s1, b1);
fsum.data[2] = __hfma2(fsum.data[2], s2, b2);
fsum.data[3] = __hfma2(fsum.data[3], s2, b2);
} else if constexpr (USE_SCALE) {
fsum.data[0] = __hmul2(fsum.data[0], s1);
fsum.data[1] = __hmul2(fsum.data[1], s1);
fsum.data[2] = __hmul2(fsum.data[2], s2);
fsum.data[3] = __hmul2(fsum.data[3], s2);
} else if constexpr (USE_BIAS) {
fsum.data[0] = __hadd2(fsum.data[0], b1);
fsum.data[1] = __hadd2(fsum.data[1], b1);
fsum.data[2] = __hadd2(fsum.data[2], b2);
fsum.data[3] = __hadd2(fsum.data[3], b2);
}
}
}
}
......@@ -714,10 +761,13 @@ public:
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bn = binfo.bn;
apply_bias(
fpsum, M, N, K,
args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
if constexpr (USE_BIAS || USE_SCALE) {
apply_bias(
fpsum, M, N, K,
args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
args.scale + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
}
};
......@@ -797,7 +847,21 @@ public:
using typename Base::unpack_fpsum; \
using typename Base::EpilogueDefault; \
using typename Base::EpilogueNop; \
using typename Base::EpilogueBias;
template<bool USE_BIAS, bool USE_SCALE> \
using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>; \
using Base::mma_f16xf16_f32; \
using Base::packed_fp32_to_fp16; \
using Base::packed_fp16_to_fp32; \
using Base::load_act; \
using Base::load_wgt; \
using Base::load_ascale; \
using Base::load_wscale; \
using Base::broadcast_wscale; \
using Base::broadcast_ascale; \
using Base::apply_scales; \
using Base::pack_ascales; \
using Base::pack_wscales; \
using Base::apply_act;
template<typename kernel, typename ...T>
......
......@@ -43,6 +43,41 @@ static T load(const T *addr) {
return *addr;
}
template<typename T>
__device__ __forceinline__
static T load_pred(const T *addr, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
"@loadpred ld.global.nc.b32 %0, [%1];"
"}" : "=r"(data) : "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 8) {
uint2 data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
"@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
"}" : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
"@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
"}" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
T result;
if (pred) {
result = *addr;
}
return result;
}
template<bool shmem = false, typename T>
__device__ __forceinline__
static void store(T *addr, T val) {
......@@ -76,6 +111,39 @@ static void store(T *addr, T val) {
*addr = val;
}
template<typename T>
__device__ __forceinline__
static void store_pred(T *addr, T val, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data = *reinterpret_cast<uint32_t *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.b32 [%1], %2;"
"}" :: "r"((int)pred), "l"(addr), "r"(data));
return;
}
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
"}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y));
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
"}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
return;
}
if (pred) {
*addr = val;
}
}
__device__ __forceinline__
static float2 half22float2(half2 val) {
return __half22float2(val);
......@@ -159,6 +227,21 @@ uint32_t quantize_float2<8, false>(float2 value) {
return result;
}
__device__ __forceinline__
uint32_t quantize_float2_fp4(float2 value) {
uint32_t result;
asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }" : "=r"(result) : "f"(value.y), "f"(value.x));
return result;
}
__device__ __forceinline__
uint32_t quantize_float4_fp8(float4 value) {
uint16_t lo, hi;
asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
return uint32_t(lo) | (uint32_t(hi) << 16);
}
__device__ __forceinline__
static float cuda_tanhf(float x) {
float result;
......@@ -271,4 +354,14 @@ static void unrolled_loop(F &&lambda) {
call(std::make_integer_sequence<int, cnt>());
}
// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]
__device__ __forceinline__
static float int2float_fast(int val) {
float fval;
// fval = (val & 0x7FFFFF) ^ 0x4B400000
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
return fval - 12582912.0f;
}
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -36,9 +36,23 @@ void gemm_w4a4(
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales
) {
invoke_launch(ascales.dtype(), [&]<typename Config>() {
Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
if (!fp4) {
dtype = ascales.dtype();
} else {
for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
if (tensor.valid()) {
assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
dtype = tensor.dtype();
}
}
}
invoke_launch(dtype, [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::gemm_w4a4(
act,
wgt,
......@@ -61,7 +75,10 @@ void gemm_w4a4(
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu
fuse_silu,
fp4,
alpha,
wcscales
);
});
}
......@@ -72,10 +89,10 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) {
});
}
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
);
});
}
......
This diff is collapsed.
......@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
using packed_wgt_t = typename GEMM::packed_wgt_t;
using packed_ascale_t = typename GEMM::packed_ascale_t;
using packed_wscale_t = typename GEMM::packed_wscale_t;
using packed_amscale_t = typename GEMM::packed_amscale_t;
using packed_wmscale_t = typename GEMM::packed_wmscale_t;
using packed_fpsum_t = typename GEMM::packed_fpsum_t;
using half_t = typename GEMM::half_t;
......@@ -38,9 +40,12 @@ public:
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales // packed ws [N]
);
static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
......
......@@ -30,7 +30,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales // packed ws [N]
) {
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
......@@ -68,58 +71,111 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
std::swap(grid.x, grid.y);
}
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
dispatchBool(fp4, [&]<bool USE_FP4>() {
// test_sizeof<typename Epilogue::Arguments>();
// std::apply([](auto ...args) {
// (test_sizeof<decltype(args)>(), ...);
// }, args);
using kernel = typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>;
auto func = invoke_kernel<kernel,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
// constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
if constexpr (!USE_FP4) {
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(alpha == 1.0f);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(),
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
return;
}
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
if constexpr (USE_FP4) {
dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
assert(!act_unsigned);
auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
const packed_act_t *,
const packed_wgt_t *,
const packed_amscale_t *,
const packed_wmscale_t *,
float,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(),
wscales.data_ptr<packed_wmscale_t>(),
alpha,
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
return;
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(),
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
// if constexpr (USE_FP4 && !FP4_AVAILABLE) {
// throw std::runtime_error("FP4 kernel is not available");
// }
});
};
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
if (!bias.valid()) {
return launch.template operator()<NextEpilogue>(nextArgs);
}
assert(bias.numel() == N);
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = typename GEMM::EpilogueCombination<typename GEMM::EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
typename GEMM::EpilogueBias::Arguments{
.bias = bias.data_ptr<packed_wscale_t>(),
},
nextArgs,
{}
assert(!bias.valid() || bias.numel() == N);
assert(!wcscales.valid() || wcscales.numel() == N);
dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
typename EpilogueBias::Arguments{
.bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
.scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
},
nextArgs,
{}
});
});
});
};
// auto launch_bias = launch;
......@@ -206,29 +262,32 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
static constexpr float SHIFT_GELU = 0.171875f;
dispatchBool(fp4, [&]<bool USE_FP4>() {
constexpr bool USE_UNSIGNED = !USE_FP4;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
}
});
constexpr bool USE_UNSIGNED = true;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<packed_ascale_t>(),
.shift_value = SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
}
} else if (out_linearattn.valid()) {
assert(out_vk.valid());
......@@ -326,7 +385,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
const int actualM = input.numel() / input.shape[-1];
const int actualN = input.shape[-1];
......@@ -338,8 +397,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
assert(output.shape[-1] == N / 2);
// assert(oscales.dtype() == Tensor::FP16);
assert(isTypeMatch<half_t>(oscales.dtype()));
assert(oscales.numel() == M * N / GEMM::WARP_K);
if (fp4) {
assert(oscales.dtype() == Tensor::FP8_E4M3);
assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
} else {
assert(isTypeMatch<half_t>(oscales.dtype()));
assert(oscales.numel() == M * N / GEMM::WARP_K);
}
const int rank = lora_down.shape[1];
......@@ -354,30 +418,32 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<packed_ascale_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
dispatchBool(fp4, [&]<bool USE_FP4>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
});
});
});
}
......
......@@ -100,9 +100,9 @@ void gemm_w8a8(Tensor act, // [M, K]
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
GEMM::EpilogueBias::Arguments{
GEMM::EpilogueBias<true, false>::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs,
......
......@@ -27,11 +27,14 @@ void gemm_w4a4(
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales
);
void linearattn_vk_mul_q(Tensor q, Tensor vk);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment