"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "4478b044b5649cc55da62fa02ed501607d46212c"
Unverified commit 78e097f1 authored by Ace Eldeib, committed by GitHub
Browse files

[Jax] Fix narrowing conversions (#2094)


Signed-off-by: Ace Eldeib <alexeldeib@gmail.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent d88137c4
...@@ -37,9 +37,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -37,9 +37,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto is_2x = static_cast<bool>(is_2x_int); auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n}; auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
auto output_trans_shape = std::vector<size_t>{n, m}; auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
...@@ -253,11 +253,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -253,11 +253,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto m = product(act_input_dims, 0, act_input_dims.size() - 2); auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = input_dims.back(); auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
auto act_input_shape = std::vector<size_t>{m, n * act_len}; auto act_input_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
auto output_shape = std::vector<size_t>{m, n * act_len}; auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
auto output_trans_shape = std::vector<size_t>{n * act_len, m}; auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n * act_len), m};
auto dbias_shape = std::vector<size_t>{n * act_len}; auto dbias_shape = std::vector<size_t>{static_cast<size_t>(n * act_len)};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end()); std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = auto input_tensor =
......
...@@ -118,7 +118,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -118,7 +118,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{ std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()}); static_cast<size_t>(scale_inv_buf->dimensions().back())});
} }
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
...@@ -135,7 +135,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -135,7 +135,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1), colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()}); static_cast<size_t>(colwise_scale_inv_buf->dimensions().back())});
} }
if (_norm_type == NVTE_Norm_Type::LayerNorm) { if (_norm_type == NVTE_Norm_Type::LayerNorm) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment