Commit e8700643 authored by zhuwenwen's avatar zhuwenwen
Browse files

update csrc

parent 88443051
......@@ -257,12 +257,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
// void silu_mul_fp8_quant_deep_gemm_cuda(
// const at::Tensor& input, // (E, T, 2*H)
// const at::Tensor& counts, // (E)
// at::Tensor& y_q, // (E, T, H) [OUT]
// at::Tensor& y_s, // (E, T, H//group_size) [OUT]
// int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
......
......@@ -32,12 +32,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#define stride_tag
#endif
ops.def(
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s, int group_size, "
"bool use_ue8m0, int num_parallel_tokens) -> ()");
ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
&silu_mul_fp8_quant_deep_gemm_cuda);
// ops.def(
// "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
// "y_q, Tensor! y_s, int group_size, "
// "bool use_ue8m0, int num_parallel_tokens) -> ()");
// ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
// &silu_mul_fp8_quant_deep_gemm_cuda);
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment