Commit 0dc4c8e9 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip rms_norm_dynamic_per_token_quant

parent d9ef7ce7
......@@ -68,13 +68,13 @@ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales,
double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);
// void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
// torch::Tensor const& input,
// torch::Tensor const& weight,
// torch::Tensor& scales,
// double const epsilon,
// std::optional<torch::Tensor> scale_ub,
// std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
......
......@@ -142,12 +142,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
ops.def(
"rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
"Tensor weight, Tensor! scale, float epsilon, "
"Tensor? scale_ub, Tensor!? residual) -> ()");
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
&rms_norm_dynamic_per_token_quant);
// ops.def(
// "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
// "Tensor weight, Tensor! scale, float epsilon, "
// "Tensor? scale_ub, Tensor!? residual) -> ()");
// ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
// &rms_norm_dynamic_per_token_quant);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment