"vllm/model_executor/models/lfm2_moe.py" did not exist on "c4e464333eac5a46e1cc2701e095a44057c82927"
Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
#include "core/registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
......@@ -13,12 +12,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
// conditionally compiled so impl registration is in source file
#endif
}
......
......@@ -153,63 +153,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype);
}; // namespace machete
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
......@@ -219,11 +164,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
......@@ -238,14 +178,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
torch::Tensor const& s_ch,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......@@ -278,26 +210,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
std::vector<torch::Tensor> selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x);
at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias, bool silu_activation,
const c10::optional<at::Tensor>& conv_state_indices);
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& seq_idx_,
const c10::optional<at::Tensor>& initial_states_,
const c10::optional<at::Tensor>& final_states_out_,
bool silu_activation);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_,
bool delta_softplus,
const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id);
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
#ifndef USE_ROCM
using fptr_t = int64_t;
......
......@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}
int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x >= num_query_blocks) {
......@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr[cur_query_id] = slot_num;
}
inline void verify_tensor(std::string const& name, torch::Tensor& t,
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
......@@ -211,7 +222,7 @@ void advance_step_flashinfer(
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0));
printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0));
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
......@@ -303,4 +314,4 @@ void advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}
\ No newline at end of file
}
......@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
}
}
......@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const val = static_cast<float>(input[i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val;
out[i] = quant_val;
}
}
......@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]);
float val = static_cast<float>(input[i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
}
......@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
}
}
......@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x;
int64_t const token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]);
auto val = static_cast<float>(input[i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
}
......@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const val = static_cast<float>(input[i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val;
out[i] = quant_val;
}
}
......
......@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
......@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
} else if (version_num >= 80) {
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
}
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
......@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype ", c.dtype());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#else
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else if (version_num >= 80) {
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}
\ No newline at end of file
......@@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
......
......@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
......@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
}
\ No newline at end of file
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
#include "core/registration.h"
namespace marlin {
......@@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
}
uint32_t vals[8];
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
......@@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
......@@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
......@@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
......@@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) {
......@@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return out;
}
#endif
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) {
......@@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
\ No newline at end of file
......@@ -23,6 +23,8 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
......@@ -2258,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1),
"b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
......@@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
}
\ No newline at end of file
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
#include "core/registration.h"
namespace marlin {
......@@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
......@@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
......@@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
......@@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
......@@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
......@@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
......@@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return out;
}
#endif
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) {
......@@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
\ No newline at end of file
......@@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_config: ImplConfig, num_impl_files=2):
def create_sources(impl_config: ImplConfig, num_impl_files=1):
sources = []
type_name = generate_type_signature(impl_config.type_config)
......@@ -457,7 +457,13 @@ def generate():
)),
]
schedules = list(set([x[1] for x in default_heuristic]))
# Do not use schedules = list(set(...)) because we need to make sure
# the output list is deterministic; otherwise the generated kernel file
# will be non-deterministic and causes ccache miss.
schedules = []
for _, schedule_config in default_heuristic:
if schedule_config not in schedules:
schedules.append(schedule_config)
impl_configs = []
......
......@@ -591,24 +591,27 @@ struct MacheteCollectiveMma {
tma_load_b = make_tma_copy_B(
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
int32_t scale_k =
(ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0;
int32_t group_size = (ModeHasScales) ? args.group_size : 0;
if constexpr (ModeHasScales) {
tma_load_scale = make_tma_copy_scale(make_logical_tensor(
args.ptr_S, make_shape(M, args.group_size, L), args.dS));
tma_load_scale = make_tma_copy_scale(
make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS));
}
if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
tma_load_zero = make_tma_copy_zero(make_logical_tensor(
args.ptr_Z, make_shape(M, args.group_size, L), args.dS));
tma_load_zero = make_tma_copy_zero(
make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS));
}
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0};
} else if constexpr (ModeHasScales) {
auto scale_k = (K + args.group_size - 1) / args.group_size;
if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
KernelConversionMode == ConversionMode::ConvertAndScale ||
KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
return {tma_load_a, tma_load_b, tma_load_scale,
tma_load_zero, scale_k, args.group_size};
tma_load_zero, scale_k, group_size};
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in to_underlying_arguments.");
......
......@@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in,
}
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B(cudaStream_t stream,
typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout,
typename PrepackedLayoutB::ElementB* B_out_ptr) {
static void prepack_B_template(
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset =
......
......@@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// Allocate output
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr()));
prepack_B_template<PrepackedLayoutB>(
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
return D;
};
......
......@@ -2,6 +2,8 @@
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
namespace machete {
using namespace vllm;
......@@ -78,14 +80,20 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
}
torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
vllm::ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("machete_prepack_B", &prepack_B);
m.impl("machete_gemm", &gemm);
}
// use CatchAll since supported_schedules has no tensor arguments
TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) {
m.impl("machete_supported_schedules", &supported_schedules);
}
}; // namespace machete
......@@ -26,6 +26,7 @@
#include <iostream>
#include "common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
......@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm", &marlin_gemm);
}
......@@ -30,6 +30,7 @@
#include <iostream>
#include "../dense/common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "../dense/common/mem.h"
......@@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
return d;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
}
......@@ -28,6 +28,7 @@
#include "common/base.h"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
......@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
}
......@@ -260,7 +260,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
// conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def(
......@@ -268,22 +268,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor");
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops.def("machete_supported_schedules", &machete::supported_schedules);
ops.def(
"machete_supported_schedules("
" __torch__.torch.classes._core_C.ScalarType btype"
") -> str[]");
ops.def(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor");
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
ops.def(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
// conditionally compiled so impl registration is in source file
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
......@@ -295,21 +297,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.
ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
// conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ.
ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
// conditionally compiled so impl registrations are in source file
// Dequantization for GGML.
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
......@@ -330,7 +330,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops.def(
......@@ -338,7 +338,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
......@@ -366,27 +366,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor!? x) -> Tensor[]");
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? conv_state_indices) -> Tensor");
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor!? final_states_out_,"
"bool silu_activation) -> Tensor");
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment