"docs/vscode:/vscode.git/clone" did not exist on "7c3604fb68031da36567151a9bdfe69e04de44b8"
Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x, ...@@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) { const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y); return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
} }
// Activation and gating kernel template.
// Check if all pointers are 16-byte aligned for int4 vectorized access
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first> bool act_first>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t* y_ptr = x_ptr + d;
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
// Check alignment for 128-bit vectorized access.
// All three pointers must be 16-byte aligned for safe int4 operations.
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
VLLM_LDG(&y_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}
} }
} }
...@@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)> ...@@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param( __global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) { const float param) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t* y_ptr = x_ptr + d;
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = ACT_FN(x, param) * y;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(xp[j], param) * yp[j];
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = ACT_FN(x, param) * y;
}
} }
} }
template <typename T> template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up, __device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
float alpha, float limit) { float alpha, float limit) {
// clamp gate: min=None, max=limit // Clamp gate to (-inf, limit] and up to [-limit, limit]
const float gate_f = (float)gate; const float g = fminf((float)gate, limit);
const float clamped_gate = gate_f > limit ? limit : gate_f; const float u = fmaxf(fminf((float)up, limit), -limit);
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
// clamp up: min=-limit, max=limit return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
const float up_f = (float)up;
const float clamped_up =
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
// glu = gate * sigmoid(gate * alpha)
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
const float glu = clamped_gate * sigmoid_val;
// (up + 1) * glu
return (T)((clamped_up + 1.0f) * glu);
} }
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
template <typename scalar_t, template <typename scalar_t,
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float, scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
const float)> const float)>
__global__ void swigluoai_and_mul_kernel( __global__ void swigluoai_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
const int d, const float alpha, const float limit) { const int d, const float alpha, const float limit) {
// For interleaved data: input has 2*d elements per token (gate/up pairs)
// output has d elements per token
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
// TODO: Vectorize loads and stores. const scalar_t* in_ptr = input + token_idx * 2 * d;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { scalar_t* out_ptr = out + token_idx * d;
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]); // Check alignment for 128-bit vectorized access on input.
// up = x[..., 1::2] (odd indices) // For output we use int2 (64-bit) which has 8-byte alignment requirement.
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]); const bool in_aligned = is_16byte_aligned(in_ptr);
const bool out_aligned =
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit); (reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
if (in_aligned && out_aligned && d >= PAIRS) {
// Fast path: vectorized loop
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
// Each int2 store writes PAIRS output elements
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
const int num_vecs = d / PAIRS;
const int vec_end = num_vecs * PAIRS;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]);
int2 r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < PAIRS; j++) {
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
}
} }
} }
...@@ -217,10 +324,41 @@ __global__ void activation_kernel( ...@@ -217,10 +324,41 @@ __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t* in_ptr = input + token_idx * d;
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = ACT_FN(x);
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]), r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(vp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
out_ptr[idx] = ACT_FN(x);
}
} }
} }
......
#ifndef CPU_ATTN_MACROS_H #ifndef CPU_ARCH_MACROS_H
#define CPU_ATTN_MACROS_H #define CPU_ARCH_MACROS_H
// x86_64 // x86_64
#ifdef __x86_64__ #ifdef __x86_64__
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \ _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \ const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \ const int n_mantissa_bits = 23; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \ auto fast_exp = [&](const vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \ always_inline)) { \
__m512 values = vec.reg; \ __m512 values = vec.reg; \
auto less_ln_flt_min_mask = \ auto less_ln_flt_min_mask = \
...@@ -98,7 +98,7 @@ ...@@ -98,7 +98,7 @@
poly = vbslq_f32(hi_mask, inf, poly); \ poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \ return vbslq_f32(lo_mask, zero, poly); \
}; \ }; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) \ auto fast_exp = [&](const vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \ __attribute__((always_inline)) { \
float32x4x4_t result; \ float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \ result.val[0] = neon_expf(vec.reg.val[0]); \
...@@ -110,4 +110,4 @@ ...@@ -110,4 +110,4 @@
#endif // __aarch64__ #endif // __aarch64__
#endif #endif
\ No newline at end of file
...@@ -8,10 +8,8 @@ ...@@ -8,10 +8,8 @@
#include <sys/sysctl.h> #include <sys/sysctl.h>
#endif #endif
#include "cpu_types.hpp" #include "cpu/cpu_arch_macros.h"
#include "scratchpad_manager.h" #include "cpu/utils.hpp"
#include "cpu_attn_macros.h"
#include "utils.hpp"
namespace cpu_attention { namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON }; enum class ISA { AMX, VEC, VEC16, NEON };
...@@ -378,12 +376,13 @@ class AttentionScheduler { ...@@ -378,12 +376,13 @@ class AttentionScheduler {
static constexpr int32_t MaxQTileIterNum = 128; static constexpr int32_t MaxQTileIterNum = 128;
AttentionScheduler() : available_cache_size_(get_available_l2_size()) {} AttentionScheduler()
: available_cache_size_(cpu_utils::get_available_l2_size()) {}
torch::Tensor schedule(const ScheduleInput& input) const { torch::Tensor schedule(const ScheduleInput& input) const {
const bool casual = input.casual; const bool casual = input.casual;
const int32_t thread_num = omp_get_max_threads(); const int32_t thread_num = omp_get_max_threads();
const int64_t cache_size = get_available_l2_size(); const int64_t cache_size = cpu_utils::get_available_l2_size();
const int32_t max_num_q_per_iter = input.max_num_q_per_iter; const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
const int32_t kv_len_alignment = input.kv_block_alignment; const int32_t kv_len_alignment = input.kv_block_alignment;
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv; int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
...@@ -659,7 +658,7 @@ class AttentionScheduler { ...@@ -659,7 +658,7 @@ class AttentionScheduler {
metadata_ptr->thread_num + metadata_ptr->thread_num +
metadata_ptr->reduction_scratchpad_size_per_kv_head * metadata_ptr->reduction_scratchpad_size_per_kv_head *
(use_gqa ? input.num_heads_kv : input.num_heads_q); (use_gqa ? input.num_heads_kv : input.num_heads_q);
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc( cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(
scratchpad_size); scratchpad_size);
// metadata_ptr->print(); // metadata_ptr->print();
...@@ -667,7 +666,7 @@ class AttentionScheduler { ...@@ -667,7 +666,7 @@ class AttentionScheduler {
// test out of boundary access // test out of boundary access
// { // {
// float* cache_ptr = // float* cache_ptr =
// DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<float>(); // cpu_utils::ScratchPadManager::getl_scratchpad_manager()->get_data<float>();
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) { // for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN(); // cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
// } // }
...@@ -749,27 +748,6 @@ class AttentionScheduler { ...@@ -749,27 +748,6 @@ class AttentionScheduler {
return std::max(rounded_tile_size, round_size); return std::max(rounded_tile_size, round_size);
} }
static int64_t get_available_l2_size() {
static int64_t size = []() {
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
TORCH_CHECK_NE(l2_cache_size, -1);
return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}();
return size;
}
private: private:
int64_t available_cache_size_; int64_t available_cache_size_;
}; };
...@@ -1402,7 +1380,7 @@ class AttentionMainLoop { ...@@ -1402,7 +1380,7 @@ class AttentionMainLoop {
// init buffers // init buffers
void* scratchpad_ptr = void* scratchpad_ptr =
DNNLScratchPadManager::get_dnnl_scratchpad_manager() cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<void>(); ->get_data<void>();
AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr); AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);
...@@ -1422,8 +1400,7 @@ class AttentionMainLoop { ...@@ -1422,8 +1400,7 @@ class AttentionMainLoop {
} }
} }
const int64_t available_cache_size = const int64_t available_cache_size = cpu_utils::get_available_l2_size();
AttentionScheduler::get_available_l2_size();
const int32_t default_tile_size = const int32_t default_tile_size =
AttentionScheduler::calcu_default_tile_size( AttentionScheduler::calcu_default_tile_size(
available_cache_size, head_dim, sizeof(kv_cache_t), available_cache_size, head_dim, sizeof(kv_cache_t),
......
This diff is collapsed.
...@@ -352,6 +352,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -352,6 +352,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(bool, void* ptr) explicit FP32Vec16(bool, void* ptr)
: reg((__m512)_mm512_stream_load_si512(ptr)) {} : reg((__m512)_mm512_stream_load_si512(ptr)) {}
// strided load
explicit FP32Vec16(const float* ptr, INT32Vec16 idx)
: reg(_mm512_i32gather_ps(idx.reg, ptr, 4)) {}
explicit FP32Vec16(__m512 data) : reg(data) {} explicit FP32Vec16(__m512 data) : reg(data) {}
// de-pack 4 bit values // de-pack 4 bit values
...@@ -408,6 +412,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -408,6 +412,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return FP32Vec16(_mm512_sub_ps(reg, b.reg)); return FP32Vec16(_mm512_sub_ps(reg, b.reg));
} }
FP32Vec16 operator-() const {
return FP32Vec16(_mm512_xor_ps(reg, _mm512_set1_ps(-0.0f)));
}
FP32Vec16 operator/(const FP32Vec16& b) const { FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg)); return FP32Vec16(_mm512_div_ps(reg, b.reg));
} }
......
#include "cpu_types.hpp" #include "cpu/cpu_types.hpp"
#include "scratchpad_manager.h" #include "cpu/utils.hpp"
#include "utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16 #ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp" #include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
...@@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl( ...@@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl(
// a simple schedule policy, just to hold more B tiles in L2 and make sure // a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks // each thread has tasks
const int32_t n_partition_size = [&]() { const int32_t n_partition_size = [&]() {
const int64_t cache_size = cpu_utils::get_l2_size(); const int64_t cache_size = cpu_utils::get_available_l2_size();
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t)); int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
int64_t ps_thread_limit = n_size / thread_num; int64_t ps_thread_limit = n_size / thread_num;
ps_cache_limit = ps_cache_limit =
...@@ -179,8 +178,8 @@ void cpu_gemm_wna16_impl( ...@@ -179,8 +178,8 @@ void cpu_gemm_wna16_impl(
const int64_t b_buffer_offset = 0; const int64_t b_buffer_offset = 0;
const int64_t c_buffer_offset = b_buffer_size; const int64_t c_buffer_offset = b_buffer_size;
const int64_t buffer_size = b_buffer_size + c_buffer_size; const int64_t buffer_size = b_buffer_size + c_buffer_size;
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size * cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size *
thread_num); thread_num);
alignas(64) cpu_utils::Counter counter; alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter; cpu_utils::Counter* counter_ptr = &counter;
...@@ -190,9 +189,10 @@ void cpu_gemm_wna16_impl( ...@@ -190,9 +189,10 @@ void cpu_gemm_wna16_impl(
scalar_t* __restrict__ b_buffer = nullptr; scalar_t* __restrict__ b_buffer = nullptr;
float* __restrict__ c_buffer = nullptr; float* __restrict__ c_buffer = nullptr;
{ {
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager() uint8_t* buffer_ptr =
->get_data<uint8_t>() + cpu_utils::ScratchPadManager::get_scratchpad_manager()
thread_id * buffer_size; ->get_data<uint8_t>() +
thread_id * buffer_size;
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset); b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset); c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
} }
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
#include "common/memory_desc.hpp" #include "common/memory_desc.hpp"
#include "common/memory.hpp" #include "common/memory.hpp"
#include "dnnl_helper.h" #include "cpu/utils.hpp"
#include "scratchpad_manager.h" #include "cpu/dnnl_helper.h"
static dnnl::engine& default_engine() { static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0); static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
...@@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { ...@@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5); auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
scratchpad_storage->set_data_handle( scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>()); cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_); matmul.execute(default_stream(), memory_cache_);
default_stream().wait(); default_stream().wait();
...@@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( ...@@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return m_size_cache_->get_or_create(key, [&]() { return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size()); manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc); return dnnl::matmul(desc);
}); });
...@@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) { ...@@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3); auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle( scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>()); cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_); matmul.execute(default_stream(), memory_cache_);
default_stream().wait(); default_stream().wait();
...@@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( ...@@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
} }
return m_size_cache_->get_or_create(key, [&]() { return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size()); manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc); return dnnl::matmul(desc);
}); });
......
...@@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> { ...@@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
} }
} }
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
constexpr int32_t elem_num_per_group = 4 / sizeof(scalar_t);
TORCH_CHECK_EQ(output_size % 16, 0);
TORCH_CHECK_EQ(input_size % (16 * elem_num_per_group), 0);
const int32_t output_group_num = output_size / 16;
const int32_t input_32b_num = input_size / elem_num_per_group;
for (int32_t output_group_idx = 0; output_group_idx < output_group_num;
++output_group_idx) {
const int32_t* __restrict__ weight_32b =
reinterpret_cast<const int32_t*>(weight);
int32_t* __restrict__ packed_weight_32b =
reinterpret_cast<int32_t*>(packed_weight);
for (int32_t output_idx = 0; output_idx < 16; ++output_idx) {
for (int32_t weight_offset = 0, packed_offset = 0;
weight_offset < input_32b_num;
++weight_offset, packed_offset += 16) {
packed_weight_32b[packed_offset] = weight_32b[weight_offset];
}
// update
weight_32b += input_32b_num;
packed_weight_32b += 1;
}
// update
weight += 16 * input_size;
packed_weight += 16 * input_size;
}
}
private: private:
alignas(64) __tilecfg amx_tile_config_; alignas(64) __tilecfg amx_tile_config_;
int32_t curr_m_; int32_t curr_m_;
......
...@@ -13,6 +13,9 @@ namespace cpu_micro_gemm { ...@@ -13,6 +13,9 @@ namespace cpu_micro_gemm {
#define CPU_MICRO_GEMM_PARAMS \ #define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
// Note: weights for MicroGemm should be packed as (output_size / 16) contiguous
// blocks, means the logical shape of blocks is [16, input_size]. And the actual
// layout of blocks can be ISA-specific.
template <cpu_utils::ISA isa, typename scalar_t> template <cpu_utils::ISA isa, typename scalar_t>
class MicroGemm { class MicroGemm {
public: public:
...@@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr, ...@@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
curr_d += ldd; curr_d += ldd;
} }
} }
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void add_bias_epilogue(float* c_ptr, float* d_ptr,
scalar_t* __restrict__ bias_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
constexpr int32_t n_group_num = n_size / 16;
static_assert(n_group_num <= 16);
vec_op::FP32Vec16 bias_vecs[n_group_num];
scalar_t* __restrict__ curr_bias = bias_ptr;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
scalar_vec_t vec(curr_bias);
bias_vecs[i] = vec_op::FP32Vec16(vec);
curr_bias += 16;
});
float* curr_c = c_ptr;
float* curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* curr_c_iter = curr_c;
float* curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
c_vec_fp32.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
} // namespace cpu_micro_gemm } // namespace cpu_micro_gemm
#endif #endif
...@@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> { ...@@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS); TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
} }
// Note: pack contiguous weight [output_size, input_size] as contiguous
// packed weight [output_size / 16, input_size, 16]
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
TORCH_CHECK_EQ(output_size % 16, 0);
for (int32_t o_idx = 0; o_idx < output_size; ++o_idx) {
const scalar_t* __restrict__ curr_weight = weight + o_idx * input_size;
scalar_t* __restrict__ curr_packed_weight =
packed_weight + (o_idx / 16) * (16 * input_size) + o_idx % 16;
for (int32_t i_idx = 0; i_idx < input_size; ++i_idx) {
*curr_packed_weight = *curr_weight;
curr_packed_weight += 16;
++curr_weight;
}
}
}
}; };
} // namespace cpu_micro_gemm } // namespace cpu_micro_gemm
......
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
#endif
...@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight, ...@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
const std::optional<torch::Tensor>& bias, const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint); const int64_t pack_factor, const std::string& isa_hint);
void prepack_moe_weight(const torch::Tensor& weight,
torch::Tensor& packed_weight, const std::string& isa);
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const torch::Tensor& w13, const torch::Tensor& w2,
const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act,
const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
...@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"pack_factor, str isa_hint) -> ()"); "pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16); ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif #endif
// fused moe
#if defined(__AVX512F__)
ops.def(
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
"-> ()");
ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
} }
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#define gettid() syscall(SYS_gettid) #define gettid() syscall(SYS_gettid)
#endif #endif
#include "cpu_types.hpp" #include "cpu/utils.hpp"
#ifdef VLLM_NUMA_DISABLED #ifdef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) { std::string init_cpu_threads_env(const std::string& cpu_ids) {
...@@ -138,4 +138,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { ...@@ -138,4 +138,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
return ss.str(); return ss.str();
} }
#endif #endif // VLLM_NUMA_DISABLED
namespace cpu_utils {
ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void ScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
ScratchPadManager* ScratchPadManager::get_scratchpad_manager() {
static ScratchPadManager manager;
return &manager;
}
} // namespace cpu_utils
...@@ -2,19 +2,24 @@ ...@@ -2,19 +2,24 @@
#define UTILS_HPP #define UTILS_HPP
#include <atomic> #include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h> #include <unistd.h>
#include <ATen/cpu/Utils.h>
#if defined(__APPLE__) #include "cpu/cpu_types.hpp"
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
namespace cpu_utils { namespace cpu_utils {
enum class ISA { AMX, VEC }; enum class ISA { AMX, VEC };
inline ISA get_isa(const std::string& isa) {
if (isa == "amx") {
return ISA::AMX;
} else if (isa == "vec") {
return ISA::VEC;
} else {
TORCH_CHECK(false, "Invalid isa type: " + isa);
}
}
template <typename T> template <typename T>
struct VecTypeTrait { struct VecTypeTrait {
using vec_t = void; using vec_t = void;
...@@ -48,26 +53,66 @@ struct Counter { ...@@ -48,26 +53,66 @@ struct Counter {
int64_t acquire_counter() { return counter++; } int64_t acquire_counter() { return counter++; }
}; };
inline int64_t get_l2_size() { inline int64_t get_available_l2_size() {
static int64_t size = []() { static int64_t size = []() {
#if defined(__APPLE__) const uint32_t l2_cache_size = at::cpu::L2_cache_size();
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
assert(l2_cache_size != -1);
return l2_cache_size >> 1; // use 50% of L2 cache return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}(); }();
return size; return size;
} }
template <int32_t alignment_v, typename T>
inline T round_up(T size) {
T alignment = alignment_v;
return (((size + alignment - 1) / alignment) * alignment);
}
template <int32_t alignment_v, typename T>
inline T round_down(T size) {
T alignment = alignment_v;
return (size / alignment) * alignment;
}
template <typename T>
inline void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
class ScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static ScratchPadManager* get_scratchpad_manager();
ScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
} // namespace cpu_utils } // namespace cpu_utils
#endif #endif
...@@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, ...@@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop.location.id = device; prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
#ifndef USE_ROCM
int flag = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED,
device));
if (flag) { // support GPUDirect RDMA if possible
prop.allocFlags.gpuDirectRDMACapable = 1;
}
#endif
#ifndef USE_ROCM #ifndef USE_ROCM
// Allocate memory using cuMemCreate // Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
......
...@@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) { ...@@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
template <ScoringFunc SF, typename T> template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) { __device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) { if constexpr (SF == SCORING_NONE) {
return val;
} else if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val); return apply_sigmoid(val);
} else { } else {
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
"Unsupported ScoringFunc in apply_scoring");
return val; return val;
} }
} }
...@@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) { if (case_id < num_tokens) {
if (if_proceed_next_topk) { if (if_proceed_next_topk) {
float scale = routed_scaling_factor;
if (renormalize) {
scale /= topk_sum;
}
for (int i = lane_id; i < topk; i += WARP_SIZE) { for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]); float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor) float value = base * scale;
: (base * routed_scaling_factor);
topk_indices[i] = s_topk_idx[i]; topk_indices[i] = s_topk_idx[i];
topk_values[i] = value; topk_values[i] = value;
} }
......
sm*_kernel_*.cu sm*_kernel_*.cu
kernel_selector.h kernel_selector.h
kernel_*.cu
...@@ -10,6 +10,8 @@ import jinja2 ...@@ -10,6 +10,8 @@ import jinja2
ARCHS = [] ARCHS = []
SUPPORT_FP8 = False SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","): for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "") arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch) arch = int(arch)
...@@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","): ...@@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration. # with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]: if arch in [89, 120]:
SUPPORT_FP8 = True SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """ FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py // auto generated by generate_kernels.py
...@@ -157,6 +163,7 @@ def remove_old_kernels(): ...@@ -157,6 +163,7 @@ def remove_old_kernels():
def generate_new_kernels(): def generate_new_kernels():
result_dict = {} result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS: for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
...@@ -174,6 +181,8 @@ def generate_new_kernels(): ...@@ -174,6 +181,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type) s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict: if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = [] result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs all_group_blocks, all_m_blocks, all_thread_configs
...@@ -197,78 +206,89 @@ def generate_new_kernels(): ...@@ -197,78 +206,89 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16, "thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16, "thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false", "m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages", "stages": 4,
"group_blocks": group_blocks, "group_blocks": group_blocks,
"is_zp_float": "false", "is_zp_float": "false",
} }
result_dict[(a_type, b_type, c_type)].append(config) if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items(): for result_dict_tmp in [result_dict, sm_75_result_dict]:
all_template_str_list = [] for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
for config in config_list: all_template_str_list = []
s_type = config["s_type"] if not config_list:
template_str = jinja2.Template(TEMPLATE).render( continue
a_type_id=f"vllm::{a_type}.id()", for config in config_list:
b_type_id=f"vllm::{b_type}.id()", s_type = config["s_type"]
c_type_id=f"vllm::{c_type}.id()", template_str = jinja2.Template(TEMPLATE).render(
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()", a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()", b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()", c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()", s_type_id=f"vllm::{s_type}.id()",
**config, **config,
) )
+ "\n" all_template_str_list.append(template_str)
)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
file_content = FILE_HEAD + "\n\n" kernel_selector_str += (
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" jinja2.Template(kernel_template2).render(
if a_type == "kFE4M3fn": a_type_id=f"vllm::{a_type}.id()",
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" b_type_id=f"vllm::{b_type}.id()",
else: c_type_id=f"vllm::{c_type}.id()",
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower() filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += ( kernel_selector_str += (
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h" #include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...@@ -35,7 +36,7 @@ ...@@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16 template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
...@@ -84,146 +85,6 @@ __global__ void Marlin( ...@@ -84,146 +85,6 @@ __global__ void Marlin(
#else #else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
asm volatile(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id> template <int count, vllm::ScalarTypeId type_id>
...@@ -439,9 +300,20 @@ __global__ void Marlin( ...@@ -439,9 +300,20 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif #endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>; using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>; using Cdtype = MarlinScalarType<c_type_id>;
...@@ -618,7 +490,22 @@ __global__ void Marlin( ...@@ -618,7 +490,22 @@ __global__ void Marlin(
} }
} }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if constexpr (moe_block_size >= 16)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
if constexpr (moe_block_size >= 8)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
if constexpr (moe_block_size >= 4)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
if constexpr (moe_block_size >= 2)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
block_num_valid_tokens = local_count;
#else
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
#endif
if (lane_id == 0) if (lane_id == 0)
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens; reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
...@@ -1018,10 +905,6 @@ __global__ void Marlin( ...@@ -1018,10 +905,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage); : (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size; int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
...@@ -1545,11 +1428,13 @@ __global__ void Marlin( ...@@ -1545,11 +1428,13 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1, mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]); frag_c[i][j][0]);
} else { } else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]); mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]); frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
} }
} }
} }
...@@ -1583,10 +1468,12 @@ __global__ void Marlin( ...@@ -1583,10 +1468,12 @@ __global__ void Marlin(
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0], mma<a_type_id, false, 32>(
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); frag_a[k2][i], frag_b[0],
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1], (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
} }
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
...@@ -2132,6 +2019,21 @@ __global__ void Marlin( ...@@ -2132,6 +2019,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing // While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation. // the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) { if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) { if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks]; float frag_a_s[2 * thread_m_blocks];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment