Unverified Commit 5dd62c3a authored by Chunyuan WU, committed by GitHub

Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339)


Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
parent f11481b9
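
For context, a minimal usage sketch of the new fp8 w8a16 shared-expert path. The helper names and argument order mirror the unit test added in this PR; the shapes, block sizes, and random scales below are illustrative only and simply satisfy the divisibility checks in shared_expert_cpu.

import torch
from sgl_kernel.common_ops import convert_weight_packed
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert

M, N, K = 2, 512, 256
BLOCK_N, BLOCK_K = 64, 128  # weight-scale block sizes

hidden_states = torch.randn(M, K, dtype=torch.bfloat16)
fused_experts_out = torch.randn(M, K, dtype=torch.bfloat16)  # output of the routed experts
w1 = torch.randn(2 * N, K).to(torch.float8_e4m3fn)  # fused gate/up projection
w2 = torch.randn(K, N).to(torch.float8_e4m3fn)      # down projection
w1s = torch.rand(2 * N // BLOCK_N, K // BLOCK_K)    # per-block dequant scales (float32)
w2s = torch.rand(K // BLOCK_N, N // BLOCK_K)

out = shared_expert(
    hidden_states,
    convert_weight_packed(w1),
    convert_weight_packed(w2),
    fused_experts_out,
    16.0,                # routed_scaling_factor
    True,                # inplace
    False,               # use_int8_w8a8
    True,                # use_fp8_w8a16
    w1s,
    w2s,
    [BLOCK_N, BLOCK_K],  # block_size
    None,                # a1_scale
    None,                # a2_scale
    True,                # is_vnni (weights pre-packed above)
)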
@@ -104,6 +104,24 @@ void shared_expert_int8_kernel_impl(
int64_t N,
int64_t K);
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
@@ -134,3 +152,20 @@ void tinygemm_kernel(
int64_t ldb,
int64_t ldc,
bool brg);
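// tinygemm with fp8 (e4m3) weights: `scale` points to the block-wise weight scales for the
// current N tile and `block_size_K` is the quantization granularity along K; Btmp/Ctmp are
// caller-provided scratch buffers (see the call sites in the new fp8 shared-expert kernel).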
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K);
@@ -248,38 +248,6 @@ struct brgemm {
}
};
template <typename scalar_t, bool has_bias>
struct brgemm<scalar_t, scalar_t, has_bias> {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
UNUSED(scale);
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <bool has_bias>
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
static inline void apply(
@@ -469,6 +437,46 @@ void fp8_scaled_mm_kernel_impl(
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K) {
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const TYPE* __restrict__ A, \
const at::Float8_e4m3fn* __restrict__ B, \
TYPE* __restrict__ C, \
TYPE* __restrict__ Btmp, \
float* __restrict__ Ctmp, \
const float* __restrict__ scale, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg, \
int64_t block_size_K)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
......
@@ -1137,8 +1137,10 @@ at::Tensor shared_expert_cpu(
double routed_scaling_factor,
bool inplace,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale,
std::optional<std::vector<int64_t>> block_size,
std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale,
bool is_vnni) {
@@ -1180,6 +1182,11 @@ at::Tensor shared_expert_cpu(
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
}
if (use_fp8_w8a16) {
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
}
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
@@ -1191,12 +1198,18 @@ at::Tensor shared_expert_cpu(
// 3. Aq_tmp : [M, K] or [M, N]
// 4. As_tmp : [M]
//
// for fp8 w8a16:
// 5. intermediate_cache0 : [M, 2N]
//
int num_threads = at::get_num_threads();
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
if (use_int8_w8a8) {
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
}
if (use_fp8_w8a16) {
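// intermediate_cache0: [M, 2N] activations in bf16/fp16, i.e. 2 bytes per element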
buffer_size_nbytes += M * 2 * N * 2;
}
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] {
@@ -1228,6 +1241,36 @@ at::Tensor shared_expert_cpu(
M,
N,
K);
} else if (use_fp8_w8a16) {
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
auto w1s = w1_scale.value();
auto w2s = w2_scale.value();
auto block_size_val = block_size.value();
TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
int64_t block_size_N = block_size_val[0];
int64_t block_size_K = block_size_val[1];
TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
TORCH_CHECK(w1s.size(1) == K / block_size_K);
TORCH_CHECK(w2s.size(0) == K / block_size_N);
TORCH_CHECK(w2s.size(1) == N / block_size_K);
shared_expert_fp8_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache0,
intermediate_cache1,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<at::Float8_e4m3fn>(),
packed_w2.data_ptr<at::Float8_e4m3fn>(),
w1s.data_ptr<float>(),
w2s.data_ptr<float>(),
block_size_N,
block_size_K,
fused_experts_out.data_ptr<scalar_t>(),
routed_scaling_factor,
M,
N,
K);
} else {
shared_expert_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
......
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// out = input + input2 * scale
template <typename scalar_t>
inline void add_mul_stub(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ input2,
float scale,
int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_vec = fVec(scale);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input + d);
fVec x0, x1;
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
bVec y_bvec = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
x0 = x0 + y0 * s_vec;
x1 = x1 + y1 * s_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
}
}
template <typename scalar_t>
inline void silu_and_mul_stub(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const scalar_t* __restrict__ input2, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec one = fVec(1.f);
// no remainder handling: `size` is assumed to be a multiple of bVec::size()
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += bVec::size()) {
bVec x = bVec::loadu(input + d);
fVec x0, x1;
std::tie(x0, x1) = at::vec::convert_to_float(x);
bVec y = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y);
x0 = x0 / (one + x0.neg().exp_u20());
x1 = x1 / (one + x1.neg().exp_u20());
x0 = x0 * y0;
x1 = x1 * y1;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
}
} // anonymous namespace
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 1: intermediate_cache0 = hidden_states @ w1
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(2 * N, BLOCK_N);
int64_t scale_size_K = div_up(K, block_size_K);
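// number of BLOCK_N tiles that share one weight-scale group along the N dimension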
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t nb_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
tinygemm_kernel<scalar_t>(
/* A */ input + mb * BLOCK_M * K,
/* B */ packed_w1 + nb * BLOCK_N * K,
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ 2 * N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [K, N] as [OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(K, BLOCK_N);
scale_size_K = div_up(N, block_size_K);
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t nb_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ ic1 + mb * BLOCK_M * N,
/* B */ packed_w2 + nb * BLOCK_N * N,
/* C */ C,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ mb_size,
/* N */ nb_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ nb_size,
/* ldc */ BLOCK_N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
// 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
for (int64_t m = 0; m < mb_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, nb_size);
}
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
}
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \
template void shared_expert_fp8_kernel_impl<TYPE>( \
TYPE* __restrict__ output, \
TYPE* __restrict__ ic0, \
TYPE* __restrict__ ic1, \
const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \
const float* __restrict__ w1s, \
const float* __restrict__ w2s, \
int64_t block_size_N, \
int64_t block_size_K, \
const TYPE* __restrict__ fused_experts_out, \
float routed_scaling_factor, \
int64_t M, \
int64_t N, \
int64_t K)
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);
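
For reference, what the kernel above computes, written in plain PyTorch. This mirrors the reference path in the unit test below; w1_scaled and w2_scaled stand for the fp8 weights dequantized with their block scales (as produced by scaled_weight in the test utilities).

import torch
import torch.nn.functional as F

def shared_expert_fp8_reference(x, w1_scaled, w2_scaled, fused_experts_out, routed_scaling_factor):
    # stage 1: [M, K] @ [K, 2N] -> [M, 2N]
    ic0 = x.float() @ w1_scaled.t()
    # stage 1.5: SiLU(gate) * up -> [M, N]
    gate, up = ic0.chunk(2, dim=-1)
    ic1 = F.silu(gate) * up
    # stage 2: [M, N] @ [N, K] -> [M, K], then add the scaled routed-expert output
    out = ic1 @ w2_scaled.t() + fused_experts_out.float() * routed_scaling_factor
    return out.to(x.dtype)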
@@ -139,8 +139,10 @@ at::Tensor shared_expert_cpu(
double routed_scaling_factor,
bool inplace,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale,
std::optional<std::vector<int64_t>> block_size,
std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale,
bool is_vnni);
......
@@ -61,6 +61,7 @@ sources = [
"csrc/cpu/gemm_fp8.cpp",
"csrc/cpu/gemm_int8.cpp",
"csrc/cpu/moe.cpp",
"csrc/cpu/moe_fp8.cpp",
"csrc/cpu/moe_int8.cpp",
"csrc/cpu/norm.cpp",
"csrc/cpu/qkv_proj.cpp",
......
import itertools
import math
import unittest
import torch
import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import convert_weight_packed
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
from utils import (
BLOCK_K,
BLOCK_N,
SiluAndMul,
factor_for_scale,
fp8_max,
fp8_min,
per_token_quant_int8,
precision,
scaled_weight,
torch_naive_moe,
torch_w8a8_per_column_moe,
)
from sglang.test.test_utils import CustomTestCase
class TestSharedExpert(CustomTestCase):
M = [2, 121]
N = [32, 32 * 4]
K = [32, 32 * 2]
routed_scaling_factor = [16]
M_fp8 = [2, 12]
N_fp8 = [512]
K_fp8 = [256]
def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
dtype = torch.bfloat16
prepack = True
hidden_states = torch.randn(m, k, dtype=dtype) / k
w1 = torch.randn(2 * n, k, dtype=dtype)
w2 = torch.randn(k, n, dtype=dtype)
fused_output = torch.randn(m, k, dtype=dtype) / k
# fused moe mutates content in hs
hidden_states2 = hidden_states.clone()
# bfloat16
ref = torch_naive_moe(
hidden_states.float(),
w1.float(),
w2.float(),
fused_output.float(),
routed_scaling_factor,
).to(dtype=dtype)
res = shared_expert(
    hidden_states,
    w1,
    w2,
    fused_output,
    routed_scaling_factor,
    True,  # inplace
    False,  # use_int8_w8a8
    False,  # use_fp8_w8a16
    None,  # w1_scale
    None,  # w2_scale
    None,  # block_size
    None,  # a1_scale
    None,  # a2_scale
    False,  # is_vnni
)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, res, atol=atol, rtol=rtol))
def test_bf16_shared_expert(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.routed_scaling_factor,
):
with self.subTest(
m=params[0],
n=params[1],
k=params[2],
routed_scaling_factor=params[3],
):
self._bf16_shared_expert(*params)
def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
dtype = torch.bfloat16
prepack = True
hidden_states = torch.randn(m, k, dtype=dtype) / k
w1 = torch.randn(2 * n, k, dtype=dtype)
w2 = torch.randn(k, n, dtype=dtype)
fused_output = torch.randn(m, k, dtype=dtype) / k
# fused moe mutates content in hs
hidden_states2 = hidden_states.clone()
w1_q, w1_s = per_token_quant_int8(w1)
w2_q, w2_s = per_token_quant_int8(w2)
ref2 = torch_w8a8_per_column_moe(
hidden_states2.float(),
w1_q,
w2_q,
w1_s,
w2_s,
fused_output.float(),
routed_scaling_factor,
).to(dtype=dtype)
res2 = shared_expert(
    hidden_states2,
    w1_q,
    w2_q,
    fused_output,
    routed_scaling_factor,
    True,  # inplace
    True,  # use_int8_w8a8
    False,  # use_fp8_w8a16
    w1_s,
    w2_s,
    None,  # block_size
    None,  # a1_scale
    None,  # a2_scale
    False,  # is_vnni
)
atol = rtol = precision[ref2.dtype]
self.assertTrue(torch.allclose(ref2, res2, atol=atol, rtol=rtol))
def test_int8_shared_expert(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.routed_scaling_factor,
):
with self.subTest(
m=params[0],
n=params[1],
k=params[2],
routed_scaling_factor=params[3],
):
self._int8_shared_expert(*params)
def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
dtype = torch.bfloat16
prepack = True
a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
w1_fp32 = torch.randn(1, 2 * N, K)
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w2_fp32 = torch.randn(1, K, N)
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
w2_scaled = scaled_weight(w2, w2s).view(K, N)
# change back to 2D
w1, w2 = w1.squeeze(0), w2.squeeze(0)
w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0)
fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
a2 = a.clone()
# ref
ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
ic1 = SiluAndMul(ic0)
shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
ref_out = shared_out + fused_out.float() * routed_scaling_factor
ref_out = ref_out.to(dtype=dtype)
w1 = convert_weight_packed(w1) # [2N, K]
w2 = convert_weight_packed(w2) # [K, N]
out = shared_expert(
    a2,
    w1,
    w2,
    fused_out,
    routed_scaling_factor,
    True,  # inplace
    False,  # use_int8_w8a8
    True,  # use_fp8_w8a16
    w1s,
    w2s,
    [BLOCK_N, BLOCK_K],  # block_size
    None,  # a1_scale
    None,  # a2_scale
    True,  # is_vnni
)
atol = rtol = precision[ref_out.dtype]
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
def test_fp8_shared_expert(self):
for params in itertools.product(
self.M_fp8,
self.N_fp8,
self.K_fp8,
self.routed_scaling_factor,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
routed_scaling_factor=params[3],
):
self._fp8_shared_expert(*params)
if __name__ == "__main__":
unittest.main()
import math
import torch
import torch.nn.functional as F
precision = {
torch.bfloat16: 1e-2,
@@ -9,6 +10,16 @@ precision = {
}
BLOCK_N, BLOCK_K = 64, 128
factor_for_scale = 1e-3
fp8_max, fp8_min = 400, -400
def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def per_token_quant_int8(x):
x = x.float()
absmax = x.abs().max(dim=-1).values
@@ -94,3 +105,46 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16
C.add_(bias.view(1, -1))
return C.reshape(origin_C_shape).to(output_dtype)
def torch_naive_moe(a, w1, w2, b, routed_scaling_factor):
ic1 = torch.matmul(a, w1.transpose(0, 1))
ic2 = SiluAndMul(ic1)
ic3 = torch.matmul(ic2, w2.transpose(0, 1))
return ic3 + b * routed_scaling_factor
def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):
# Perform per-token quantization
a_q, a_s = per_token_quant_int8(a)
ic1 = native_w8a8_per_token_matmul(
a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32
)
ic2 = SiluAndMul(ic1)
a1_q, a1_s = per_token_quant_int8(ic2)
ic3 = native_w8a8_per_token_matmul(
a1_q, w2_q, a1_s, w2_s, bias=None, output_dtype=torch.float32
)
return ic3 + b * routed_scaling_factor
def scaled_weight(weight, scales):
E, N, K = weight.shape
weight_block = (
weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
.permute(0, 1, 3, 2, 4)
.float()
.contiguous()
)
return (
(weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
.permute(0, 1, 3, 2, 4)
.contiguous()
.view(E, N, K)
)
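
A quick shape check for scaled_weight under hypothetical sizes (one expert, N = 128, K = 256, with the BLOCK_N/BLOCK_K defined above); each scale entry is broadcast over one [BLOCK_N, BLOCK_K] tile of the weight:

w = torch.randn(1, 128, 256).to(torch.float8_e4m3fn)
s = torch.rand(1, 128 // BLOCK_N, 256 // BLOCK_K)  # one scale per [BLOCK_N, BLOCK_K] tile
w_dq = scaled_weight(w, s)
assert w_dq.shape == (1, 128, 256) and w_dq.dtype == torch.float32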