Unverified Commit 4478b044 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Replace `int8_t` in Pybind11 extensions with `int64_t` (#882)



Replace int8_t in PyTorch extensions with int64_t
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent e9606077
...@@ -400,7 +400,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -400,7 +400,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
int64_t fp8_tensor, int64_t fp8_tensor,
int64_t otype, int64_t otype,
const int8_t sm_margin, const int64_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype); transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
...@@ -424,7 +424,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, ...@@ -424,7 +424,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &bias,
double eps, double eps,
const int8_t sm_margin, const int64_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
...@@ -447,7 +447,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -447,7 +447,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
int64_t fp8_tensor, int64_t fp8_tensor,
int64_t otype, int64_t otype,
const int8_t sm_margin, const int64_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype); transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
...@@ -469,7 +469,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -469,7 +469,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
double eps, double eps,
const int8_t sm_margin, const int64_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment