Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
...@@ -36,6 +36,10 @@ struct ConvParamsBase { ...@@ -36,6 +36,10 @@ struct ConvParamsBase {
void *__restrict__ conv_state_ptr; void *__restrict__ conv_state_ptr;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t *__restrict__ conv_state_indices_ptr;
void *__restrict__ seq_idx_ptr; void *__restrict__ seq_idx_ptr;
// No __restrict__ since initial_states could be the same as final_states. // No __restrict__ since initial_states could be the same as final_states.
......
...@@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream); selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
}); });
std::vector<at::Tensor> result = {out, x.value()}; std::vector<at::Tensor> result = {out};
if (has_z) { result.push_back(out_z); } if (has_z) { result.push_back(out_z); }
return result; return result;
} }
......
This diff is collapsed.
#include "marlin_moe_kernel_ku4b8.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks) {
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks);
} // namespace marlin_moe
#include "marlin_moe_kernel_ku8b128.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks) {
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks);
}
This diff is collapsed.
...@@ -2,11 +2,14 @@ ...@@ -2,11 +2,14 @@
#include <torch/all.h> #include <torch/all.h>
#include "core/scalar_type.hpp"
torch::Tensor marlin_gemm_moe( torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm, const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights); bool replicate_input, bool apply_weights);
...@@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " "g_idx, Tensor! perm, Tensor! workspace, "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"bool replicate_input, bool apply_weights) -> Tensor"); "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif #endif
} }
......
...@@ -155,6 +155,8 @@ torch::Tensor prepack_B(torch::Tensor const& B, ...@@ -155,6 +155,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,
}; // namespace machete }; // namespace machete
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta, torch::Tensor& b_meta,
torch::Tensor& b_scales, torch::Tensor& b_scales,
...@@ -226,10 +228,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, ...@@ -226,10 +228,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale); torch::Tensor const& scale,
c10::optional<torch::Tensor> const& azp);
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales); torch::Tensor& scales,
c10::optional<torch::Tensor> const& azp);
// torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, // torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
// torch::Tensor b_gptq_qzeros, // torch::Tensor b_gptq_qzeros,
...@@ -262,11 +266,10 @@ std::vector<torch::Tensor> selective_scan_fwd( ...@@ -262,11 +266,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
const c10::optional<torch::Tensor>& index_, const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x); const c10::optional<torch::Tensor>& x);
at::Tensor causal_conv1d_update(const at::Tensor& x, at::Tensor causal_conv1d_update(
const at::Tensor& conv_state, const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const at::Tensor& weight, const c10::optional<at::Tensor>& bias, bool silu_activation,
const c10::optional<at::Tensor>& bias_, const c10::optional<at::Tensor>& conv_state_indices);
bool silu_activation);
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_, const c10::optional<at::Tensor>& bias_,
...@@ -281,8 +284,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, ...@@ -281,8 +284,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank, const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink); bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out); torch::Tensor& out);
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
static constexpr int default_threads = 256;
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16bit types (since we permute half types)
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = std::max(finish_row - start_row, 0);
int row_stride = size_k * sizeof(half) / 16;
auto permute_row = [&](int row) {
int iters = size_k / default_threads;
int rest = size_k % default_threads;
int offset = row * row_stride;
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += default_threads;
}
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto dev = A.get_device();
auto stream = at::cuda::getCurrentCUDAStream(dev);
TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
"Currently only 16bit types are supported");
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
TORCH_CHECK(A.size(-1) % 8 == 0,
"A columns must be a multiple of 8 (128bits)");
auto A_2d = A.view({-1, A.size(-1)});
torch::Tensor D = torch::empty_like(A);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
int block_rows = div_ceil(A_2d.size(0), sms);
permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
A_2d.size(0), A_2d.size(1), block_rows);
return D;
}
\ No newline at end of file
...@@ -14,12 +14,17 @@ ...@@ -14,12 +14,17 @@
static inline __device__ int8_t float_to_int8_rn(float x) { static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM #ifdef USE_ROCM
static const float i8_min = static constexpr auto i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min()); static_cast<float>(std::numeric_limits<int8_t>::min());
static const float i8_max = static constexpr auto i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max()); static_cast<float>(std::numeric_limits<int8_t>::max());
// round
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float dst = std::nearbyint(x); float dst = std::nearbyint(x);
// saturate // saturate
dst = std::clamp(dst, i8_min, i8_max); dst = std::clamp(dst, i8_min, i8_max);
return static_cast<int8_t>(dst); return static_cast<int8_t>(dst);
...@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) { ...@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
#endif #endif
} }
static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
// int32_max is not exactly representable as float.
// Therefore, we need to be careful and manually return int32_max on overflow.
// For symmetry, we also do the same for int32_min, even though it is exactly
// representable as float and the conversion should be exact.
static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
static constexpr auto i32_min_f = static_cast<float>(i32_min);
static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
static constexpr auto i32_max_f = static_cast<float>(i32_max);
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float dst = std::nearbyint(x);
// saturate on the higher end.
if (dst >= i32_max_f) {
return i32_max;
}
// saturate on the lower end.
if (dst <= i32_min_f) {
return i32_min;
}
return static_cast<int32_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int32_t&>(dst);
#endif
}
static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
static constexpr auto i8_min =
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
static constexpr auto i8_max =
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// saturate
int32_t dst = std::clamp(x, i8_min, i8_max);
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}
namespace vllm { namespace vllm {
template <typename scalar_t, typename scale_type> template <typename scalar_t, typename scale_type>
...@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel( ...@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
} }
} }
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;
for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val;
}
}
template <typename scalar_t, typename scale_type> template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel( __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
...@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel( ...@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
} }
} }
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x;
// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
}
// Reduce the max and min values across the block
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
__syncthreads(); // Make sure min doesn't mess with max shared memory
min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
__shared__ scale_type scale_sh;
__shared__ azp_type azp_sh;
// Compute the scale and zero point and store them, only on the first thread
if (threadIdx.x == 0) {
float const scale_val = (max_val - min_val) / 255.0f;
// Use rounding to even (same as torch.round)
auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
auto const azp_val = static_cast<azp_type>(azp_float);
// Store the scale and azp into shared and global
scale[token_idx] = scale_sh = scale_val;
azp[token_idx] = azp_sh = azp_val;
}
// Wait for the scale and azp to be computed
__syncthreads();
float const scale_val = scale_sh;
azp_type const azp_val = azp_sh;
// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val;
}
}
} // namespace vllm } // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale) { torch::Tensor const& scale,
c10::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1); TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp || azp->numel() == 1);
int const hidden_size = input.size(-1); int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size; int const num_tokens = input.numel() / hidden_size;
...@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] ...@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
if (!azp) {
vllm::static_scaled_int8_quant_kernel<scalar_t, float> vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), <<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size); scale.data_ptr<float>(), hidden_size);
} else {
vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
}
}); });
} }
void dynamic_scaled_int8_quant( void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size] torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales) { torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(!azp || azp->is_contiguous());
int const hidden_size = input.size(-1); int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size; int const num_tokens = input.numel() / hidden_size;
...@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant( ...@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
if (!azp) {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float> vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), <<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size); scales.data_ptr<float>(), hidden_size);
} else {
vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
}
}); });
} }
...@@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ ...@@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
template<typename dst_t> template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x; const int64_t i = blockIdx.x;
const block_iq1_s * x = (const block_iq1_s *) vx; const block_iq1_s * x = (const block_iq1_s *) vx;
const int tid = threadIdx.x; const int64_t tid = threadIdx.x;
const int il = tid/8; // 0...3 const int64_t il = tid/8; // 0...3
const int ib = tid%8; // 0...7 const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = __half2float(x[i].d) * (2*((x[i].qh[ib] >> 12) & 7) + 1);
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
}
}
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq1_m * x = (const block_iq1_m *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il; dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const int i8 = 4*ib+il; const uint16_t * sc = (const uint16_t *)x[i].scales;
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); iq1m_scale_t scale;
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = __half2float(x[i].d) * (2*(h & 7) + 1); const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]); const float d = __half2float(scale.f16) * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
}
} }
template<typename dst_t> template<typename dst_t>
...@@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c ...@@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y); dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
} }
template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t> template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K; const int nb = (k + QK_K - 1) / QK_K;
...@@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) { ...@@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
return dequantize_row_iq2_s_cuda; return dequantize_row_iq2_s_cuda;
case 23: case 23:
return dequantize_row_iq4_xs_cuda; return dequantize_row_iq4_xs_cuda;
case 29:
return dequantize_row_iq1_m_cuda;
default: default:
return nullptr; return nullptr;
} }
......
This diff is collapsed.
...@@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight ...@@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
(void*)quant_X.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream); (half*)Y.data_ptr(), col, row, stream);
break; break;
case 29:
mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
} }
return Y; return Y;
} }
......
...@@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * ...@@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half *
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
......
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh
// and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu // and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
int x32 = x16[2*i32 + 0] << 0;
x32 |= x16[2*i32 + 1] << 16;
return x32;
}
static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
int x32 = 0; int x32 = 0;
...@@ -1661,24 +1674,76 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( ...@@ -1661,24 +1674,76 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int ib32 = iqs; const int qs_packed = get_int_b2(bq1->qs, iqs);
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; const uint8_t * qs = (const uint8_t *) &qs_packed;
const uint8_t h1 = bq1->scales[2*ib32+0];
const uint8_t h2 = bq1->scales[2*ib32+1]; const int qh = bq1->qh[iqs];
const int * q8 = (const int *)bq8_1[ib32].qs;
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); int sumi = 0;
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); #pragma unroll
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); for (int l0 = 0; l0 < 8; l0 += 2) {
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];
for (int j = 0; j < 2; ++j) {
sumi1 = __dp4a(q8[j+0], grid1[j], sumi1); const int grid0 = (grid >> 0) & 0x0F0F0F0F;
sumi2 = __dp4a(q8[j+2], grid2[j], sumi2); const int grid1 = (grid >> 4) & 0x0F0F0F0F;
sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
sumi4 = __dp4a(q8[j+6], grid4[j], sumi4); const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
} const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
const float d = __half2float(bq1->d) * __low2float(bq8_1[ib32].ds);
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + sumi = __dp4a(grid0, u0, sumi);
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); sumi = __dp4a(grid1, u1, sumi);
}
const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
const float2 ds = __half22float2(bq8_1[iqs].ds);
return d1q * (ds.x*sumi + ds.y*delta);
#endif
}
static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
const int qs_packed = get_int_b4(bq1->qs, iqs);
const uint8_t * qs = (const uint8_t *) &qs_packed;
int sumi[2] = {0};
float sumf[2] = {0.0f};
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));
const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
sumi[l0/4] = __dp4a(grid0, u0, sumi[l0/4]);
sumi[l0/4] = __dp4a(grid1, u1, sumi[l0/4]);
const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
int sumy = 0;
sumy = __dp4a(u0, 0x01010101, sumy);
sumy = __dp4a(u1, 0x01010101, sumy);
sumf[l0/4] += delta*sumy;
}
const uint16_t * sc = (const uint16_t *) bq1->scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);
const int tmp = sc[iqs/2] >> (6*(iqs%2));
const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
#endif #endif
} }
......
...@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput ...@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass @dataclass(frozen=True)
class ScheduleConfig: class ScheduleConfig:
tile_shape_mn: Tuple[int, int] tile_shape_mn: Tuple[int, int]
cluster_shape_mnk: Tuple[int, int, int] cluster_shape_mnk: Tuple[int, int, int]
...@@ -328,56 +328,137 @@ def generate(): ...@@ -328,56 +328,137 @@ def generate():
# about how this works # about how this works
SCRIPT_DIR = os.path.dirname(__file__) SCRIPT_DIR = os.path.dirname(__file__)
schedules = [ schedule_common_params = dict(
ScheduleConfig( kernel_schedule=TmaMI,
tile_shape_mn=tile_shape_mn, epilogue_schedule=TmaCoop,
cluster_shape_mnk=cluster_shape_mnk, tile_scheduler=TileSchedulerType.StreamK,
kernel_schedule=kernel_schedule, )
epilogue_schedule=epilogue_schedule,
tile_scheduler=tile_scheduler,
) for tile_shape_mn, cluster_shape_mnk in (
((128, 16), (1, 1, 1)),
((128, 32), (1, 1, 1)),
((128, 64), (1, 1, 1)),
((128, 128), (1, 1, 1)),
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
for tile_scheduler in (TileSchedulerType.StreamK, )
]
# For now we use the same heuristic for all types # For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [ default_heuristic = [
("M > 64", #### M = 257+
(
"M > 256 && K <= 16384 && N <= 4096",
ScheduleConfig( ScheduleConfig(
tile_shape_mn=(128, 128), tile_shape_mn=(128, 128),
cluster_shape_mnk=(1, 1, 1), cluster_shape_mnk=(2, 1, 1),
kernel_schedule=TmaMI, **schedule_common_params # type: ignore
epilogue_schedule=TmaCoop, )),
tile_scheduler=TileSchedulerType.StreamK, (
"M > 256",
ScheduleConfig(
tile_shape_mn=(128, 256),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)), )),
("M > 32", #### M = 129-256
(
"M > 128 && K <= 4096 && N <= 4096",
ScheduleConfig( ScheduleConfig(
tile_shape_mn=(128, 64), tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 128 && K <= 8192 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 128",
ScheduleConfig(
tile_shape_mn=(128, 256),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 65-128
(
"M > 64 && K <= 4069 && N <= 4069",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64 && K <= 4069 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64 && K >= 8192 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 33-64
(
"M > 32 && K <= 6144 && N <= 6144",
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1), cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI, **schedule_common_params # type: ignore
epilogue_schedule=TmaCoop, )),
tile_scheduler=TileSchedulerType.StreamK, (
"M > 32 && K >= 16384 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 32",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)), )),
("M > 16", #### M = 17-32
(
"M > 16 && K <= 12288 && N <= 8192",
ScheduleConfig( ScheduleConfig(
tile_shape_mn=(128, 32), tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 16",
ScheduleConfig(
tile_shape_mn=(256, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 1-16
(
"N >= 26624",
ScheduleConfig(
tile_shape_mn=(256, 16),
cluster_shape_mnk=(1, 1, 1), cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI, **schedule_common_params # type: ignore
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)), )),
(None, (
ScheduleConfig(tile_shape_mn=(128, 16), None,
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1), cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI, **schedule_common_params # type: ignore
epilogue_schedule=TmaCoop, )),
tile_scheduler=TileSchedulerType.StreamK))
] ]
schedules = list(set([x[1] for x in default_heuristic]))
impl_configs = [] impl_configs = []
GPTQ_kernel_type_configs = list( GPTQ_kernel_type_configs = list(
......
...@@ -152,7 +152,8 @@ struct MacheteKernelTemplate { ...@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size = maybe_group_size.value_or(K); int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size; int const scale_k = (K + group_size - 1) / group_size;
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment