Unverified Commit 195a09f5 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

fix bmm fp8 (#4926)

parent 9fccda31
......@@ -82,7 +82,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From FlashInfer
*/
m.def("bmm_fp8", bmm_fp8);
m.def(
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
m.def("top_k_renorm_probs", top_k_renorm_probs);
m.def("top_p_renorm_probs", top_p_renorm_probs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment