Unverified Commit f57d2dc1 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[sgl-kernel] avoid per_token_quant_fp8.cu hardcode sm_count (#8738)

parent f2d68ded
@@ -173,9 +173,8 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
   TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  // Hard-code sm_count
-  int sm_count = 132;
-  constexpr int TOKENS_PER_CTA = 8;
+  const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  const int TOKENS_PER_CTA = 8;
   const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
   const bool use_vec16 = (hidden_dim % 16 == 0);
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment