"vscode:/vscode.git/clone" did not exist on "2f9812aad6e7bf028e6b36ce033d488cd84cdf0f"
Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
#include "core/registration.h" #include "core/registration.h"
#include "moe_ops.h" #include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs. // Apply topk softmax to the gating outputs.
...@@ -13,12 +12,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -13,12 +12,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, " "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)" "int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"); " -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); // conditionally compiled so impl registration is in source file
#endif #endif
} }
......
...@@ -153,63 +153,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, ...@@ -153,63 +153,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _zeros, int64_t split_k_iters, torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy); int64_t thx, int64_t thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype);
}; // namespace machete
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n); int64_t n);
...@@ -219,11 +164,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, ...@@ -219,11 +164,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row); int64_t row);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
...@@ -238,14 +178,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, ...@@ -238,14 +178,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias); c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
torch::Tensor const& s_ch,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
...@@ -278,26 +210,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -278,26 +210,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
std::vector<torch::Tensor> selective_scan_fwd( void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& B, const torch::Tensor& C, const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_, const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_, const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus, const c10::optional<torch::Tensor>& delta_bias_,
const c10::optional<torch::Tensor>& index_, bool delta_softplus,
const c10::optional<torch::Tensor>& x); const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
at::Tensor causal_conv1d_update( const c10::optional<torch::Tensor>& has_initial_state,
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, const torch::Tensor& ssm_states, int64_t pad_slot_id);
const c10::optional<at::Tensor>& bias, bool silu_activation,
const c10::optional<at::Tensor>& conv_state_indices); void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& bias_, bool silu_activation,
const c10::optional<at::Tensor>& seq_idx_, const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& initial_states_, const c10::optional<at::Tensor>& conv_state_indices_,
const c10::optional<at::Tensor>& final_states_out_, int64_t pad_slot_id);
bool silu_activation);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = int64_t; using fptr_t = int64_t;
......
...@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel( ...@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long const* sampled_token_ids_ptr, long* input_positions_ptr, long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) { int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}
int num_query_blocks = div_ceil(num_queries, num_threads); int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x >= num_query_blocks) { if (blockIdx.x >= num_query_blocks) {
...@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel( ...@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr[cur_query_id] = slot_num; slot_mapping_ptr[cur_query_id] = slot_num;
} }
inline void verify_tensor(std::string const& name, torch::Tensor& t, inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1, int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) { c10::ScalarType const type) {
bool size_0_cond = true; bool size_0_cond = true;
...@@ -211,7 +222,7 @@ void advance_step_flashinfer( ...@@ -211,7 +222,7 @@ void advance_step_flashinfer(
printf(" num_seqs = %d\n", num_seqs); printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries); printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size); printf(" block_size = %d\n", block_size);
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0)); printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0));
} }
// Verify all tensors // Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
...@@ -303,4 +314,4 @@ void advance_step_flashinfer( ...@@ -303,4 +314,4 @@ void advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
paged_kv_indptr, paged_kv_last_page_len, block_table_bound); paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
} }
\ No newline at end of file
...@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel( ...@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) { scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr; scale_type const scale = *scale_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn( out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
} }
} }
...@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel( ...@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type const* scale_ptr, azp_type const* azp_ptr, scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) { const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr; scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr; azp_type const azp = *azp_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]); auto const val = static_cast<float>(input[i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val; out[i] = quant_val;
} }
} }
...@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel( ...@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) { scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
float absmax_val = 0.0f; float absmax_val = 0.0f;
float const zero = 0.0f; float const zero = 0.0f;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]); float val = static_cast<float>(input[i]);
val = val > zero ? val : -val; val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val; absmax_val = val > absmax_val ? val : absmax_val;
} }
...@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( ...@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float const tmp_scale = 127.0f / block_absmax_val; float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn( out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
} }
} }
...@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type> ...@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel( __global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) { scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
// Scan for the min and max value for this token // Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min(); float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max(); float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]); auto val = static_cast<float>(input[i]);
max_val = std::max(max_val, val); max_val = std::max(max_val, val);
min_val = std::min(min_val, val); min_val = std::min(min_val, val);
} }
...@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( ...@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// Quantize the values // Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]); auto const val = static_cast<float>(input[i]);
auto const quant_val = auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val; out[i] = quant_val;
} }
} }
......
...@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a, ...@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); c10::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
...@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, ...@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 90) { // Hopper
// Hopper
// Guard against compilation issues for sm90 kernels // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias); cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else return;
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); }
#endif #endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace // Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias); cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
} else if (version_num >= 80) { return;
}
if (version_num >= 80) {
// Ampere // Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
} else { return;
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
} }
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
} }
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
...@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, ...@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype ", c.dtype()); "currently bias dtype must match output dtype ", c.dtype());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
// Guard against compilation issues for sm90 kernels #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias); cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#else return;
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias); }
#endif #endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace // Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias); cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else if (version_num >= 80) { return;
}
if (version_num >= 80) {
// Ampere // Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias); cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else { return;
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} }
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
} }
\ No newline at end of file
...@@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( ...@@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int const token_idx = blockIdx.x;
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; // Use int64 to avoid overflowing an int32 when calculating this offset
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size]; int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be // For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively. // aligned at 8-byte and 4-byte addresses respectively.
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh" #include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh" #include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using namespace marlin; using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
} }
#endif #endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
}
\ No newline at end of file
#include "marlin.cuh" #include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #include "core/registration.h"
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
namespace marlin { namespace marlin {
...@@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel( ...@@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
} }
uint32_t vals[8]; uint32_t vals[8];
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + tc_offsets[i];
...@@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel( ...@@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4); res |= vals[pack_idx[i]] << (i * 4);
} }
...@@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel( ...@@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res1 = 0; uint32_t res1 = 0;
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8);
...@@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel( ...@@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
}; };
auto start_pipes = [&](int k_tile_id, int n_tile_id) { auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) { for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
} }
wait_for_stage(); wait_for_stage();
}; };
#pragma unroll #pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0; int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id); start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) { while (n_tile_id < n_tiles) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) { for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1); n_tile_id + pipe + repack_stages - 1);
...@@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel( ...@@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS) \ #define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \ else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) { int64_t size_n, int64_t num_bits) {
...@@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, ...@@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return out; return out;
} }
#endif
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n, c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) { int64_t num_bits) {
...@@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, ...@@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options); options);
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
\ No newline at end of file
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "marlin_dtypes.cuh" #include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \ static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \ std::is_same<scalar_t, nv_bfloat16>::value, \
...@@ -2258,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -2258,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 0 = ", b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups); " is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1), "b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor); " is not size_n / pack_factor = ", size_n / pack_factor);
} }
...@@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
} }
#endif #endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
}
\ No newline at end of file
#include "marlin.cuh" #include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #include "core/registration.h"
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
namespace marlin { namespace marlin {
...@@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t b1_vals[tile_ints]; uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints]; uint32_t b2_vals[tile_ints];
#pragma unroll #pragma unroll
for (int i = 0; i < tile_ints; i++) { for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor; int cur_int = cur_elem / pack_factor;
...@@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4); res |= vals[pack_idx[i]] << (i * 4);
} }
...@@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res1 = 0; uint32_t res1 = 0;
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8);
...@@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
}; };
auto start_pipes = [&](int k_tile_id, int n_tile_id) { auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) { for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
} }
wait_for_stage(); wait_for_stage();
}; };
#pragma unroll #pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0; int n_tile_id = 0;
...@@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
start_pipes(k_tile_id, n_tile_id); start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) { while (n_tile_id < n_tiles) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) { for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1); n_tile_id + pipe + repack_stages - 1);
...@@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \ HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
...@@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return out; return out;
} }
#endif
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k, torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) { c10::SymInt size_n, int64_t num_bits) {
...@@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, ...@@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options); options);
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
\ No newline at end of file
...@@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE) ...@@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE) prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_config: ImplConfig, num_impl_files=2): def create_sources(impl_config: ImplConfig, num_impl_files=1):
sources = [] sources = []
type_name = generate_type_signature(impl_config.type_config) type_name = generate_type_signature(impl_config.type_config)
...@@ -457,7 +457,13 @@ def generate(): ...@@ -457,7 +457,13 @@ def generate():
)), )),
] ]
schedules = list(set([x[1] for x in default_heuristic])) # Do not use schedules = list(set(...)) because we need to make sure
# the output list is deterministic; otherwise the generated kernel file
# will be non-deterministic and causes ccache miss.
schedules = []
for _, schedule_config in default_heuristic:
if schedule_config not in schedules:
schedules.append(schedule_config)
impl_configs = [] impl_configs = []
......
...@@ -591,24 +591,27 @@ struct MacheteCollectiveMma { ...@@ -591,24 +591,27 @@ struct MacheteCollectiveMma {
tma_load_b = make_tma_copy_B( tma_load_b = make_tma_copy_B(
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB)); make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
int32_t scale_k =
(ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0;
int32_t group_size = (ModeHasScales) ? args.group_size : 0;
if constexpr (ModeHasScales) { if constexpr (ModeHasScales) {
tma_load_scale = make_tma_copy_scale(make_logical_tensor( tma_load_scale = make_tma_copy_scale(
args.ptr_S, make_shape(M, args.group_size, L), args.dS)); make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS));
} }
if constexpr (KernelConversionMode == if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) { ConversionMode::ConvertAndScaleWithZero) {
tma_load_zero = make_tma_copy_zero(make_logical_tensor( tma_load_zero = make_tma_copy_zero(
args.ptr_Z, make_shape(M, args.group_size, L), args.dS)); make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS));
} }
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0}; KernelConversionMode == ConversionMode::ConvertAndScale ||
} else if constexpr (ModeHasScales) { KernelConversionMode ==
auto scale_k = (K + args.group_size - 1) / args.group_size; ConversionMode::ConvertAndScaleWithZero) {
return {tma_load_a, tma_load_b, tma_load_scale, return {tma_load_a, tma_load_b, tma_load_scale,
tma_load_zero, scale_k, args.group_size}; tma_load_zero, scale_k, group_size};
} else { } else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in to_underlying_arguments."); "Conversion mode not handled in to_underlying_arguments.");
......
...@@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in, ...@@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in,
} }
template <typename PrepackedLayoutB, typename InLayout> template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B(cudaStream_t stream, static void prepack_B_template(
typename PrepackedLayoutB::ElementB const* B_in_ptr, cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout, InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL = using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{})); decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset = auto ilvd_NKbNbKL_to_offset =
......
...@@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) { ...@@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// Allocate output // Allocate output
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt, prepack_B_template<PrepackedLayoutB>(
static_cast<ElementB*>(D.mutable_data_ptr())); stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
return D; return D;
}; };
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include "machete_prepack_launcher.cuh" #include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#include "core/registration.h"
namespace machete { namespace machete {
using namespace vllm; using namespace vllm;
...@@ -78,14 +80,20 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, ...@@ -78,14 +80,20 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
} }
torch::Tensor prepack_B(torch::Tensor const& B, torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) { vllm::ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) { return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B); return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
}); });
#else }
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("machete_prepack_B", &prepack_B);
m.impl("machete_gemm", &gemm);
}
// use CatchAll since supported_schedules has no tensor arguments
TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) {
m.impl("machete_supported_schedules", &supported_schedules);
} }
}; // namespace machete }; // namespace machete
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <iostream> #include <iostream>
#include "common/base.h" #include "common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h" #include "common/mem.h"
...@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c; return c;
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm", &marlin_gemm);
}
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <iostream> #include <iostream>
#include "../dense/common/base.h" #include "../dense/common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "../dense/common/mem.h" #include "../dense/common/mem.h"
...@@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, ...@@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
return d; return d;
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
}
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "common/base.h" #include "common/base.h"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c; return c;
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
}
...@@ -260,7 +260,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -260,7 +260,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def( ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor"); "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); // conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def( ops.def(
...@@ -268,22 +268,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -268,22 +268,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_scales, Tensor workspace, " "Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, " "__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor"); "int size_m, int size_n, int size_k) -> Tensor");
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); // conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper. // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops.def("machete_supported_schedules", &machete::supported_schedules); ops.def(
"machete_supported_schedules("
" __torch__.torch.classes._core_C.ScalarType btype"
") -> str[]");
ops.def( ops.def(
"machete_gemm(Tensor A, Tensor B," "machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype," " __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size," " Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)" " Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor"); "-> Tensor");
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
ops.def( ops.def(
"machete_prepack_B(Tensor B," "machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)" " __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor"); "-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); // conditionally compiled so impl registration is in source file
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols); ops.impl("permute_cols", torch::kCUDA, &permute_cols);
...@@ -295,21 +297,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -295,21 +297,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"__torch__.torch.classes._core_C.ScalarType b_q_type, " "__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, " "int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor"); "bool has_zp, bool use_fp32_reduce) -> Tensor");
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); // conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ. // gptq_marlin repack from GPTQ.
ops.def( ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); // conditionally compiled so impl registrations are in source file
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
// awq_marlin repack from AWQ. // awq_marlin repack from AWQ.
ops.def( ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"); "SymInt size_n, int num_bits) -> Tensor");
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); // conditionally compiled so impl registrations are in source file
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
// Dequantization for GGML. // Dequantization for GGML.
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor"); ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
...@@ -330,7 +330,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -330,7 +330,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, " "Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor"); "int size_k) -> Tensor");
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); // conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ. // marlin_qqq_gemm for QQQ.
ops.def( ops.def(
...@@ -338,7 +338,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -338,7 +338,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor s_tok, Tensor s_ch, Tensor s_group, " "Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, " "Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor"); "int size_k) -> Tensor");
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm); // conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias // quantization, as well as bias
...@@ -366,27 +366,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -366,27 +366,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def( ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta," "selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C," "Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_," "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus," "bool delta_softplus,"
"Tensor? index_, Tensor!? x) -> Tensor[]"); "Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def( ops.def(
"causal_conv1d_update(Tensor! x," "causal_conv1d_update(Tensor! x,"
"Tensor! conv_state," "Tensor! conv_state,"
"Tensor! weight," "Tensor! weight,"
"Tensor? bias," "Tensor? bias_,"
"bool silu_activation," "bool silu_activation,"
"Tensor? conv_state_indices) -> Tensor"); "Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def( ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight," "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_," "Tensor? bias_,"
"Tensor? seq_idx_," "Tensor!? conv_states,"
"Tensor? initial_states_," "Tensor? query_start_loc,"
"Tensor!? final_states_out_," "Tensor? cache_indices,"
"bool silu_activation) -> Tensor"); "Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment