Unverified commit f0572aa5 authored by Tim Moon, committed by GitHub
Browse files

Fix bugs from refactoring C++ tensor class (#2481)



Remove assumption in quantize/activation kernels that data buffer is initialized
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 50be0299
...@@ -27,9 +27,9 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t ...@@ -27,9 +27,9 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
switch (input.scaling_mode) { switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: { case NVTE_DELAYED_TENSOR_SCALING: {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
fp8::dequantize(input, output, stream); fp8::dequantize(input, output, stream);
break; break;
} }
......
...@@ -98,8 +98,8 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte ...@@ -98,8 +98,8 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
const size_t rows = gated_input.flat_first_dim(); const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2; const size_t cols = gated_input.flat_last_dim() / 2;
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision.");
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows, NVTE_CHECK(grad.flat_first_dim() == rows,
"Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
...@@ -116,9 +116,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte ...@@ -116,9 +116,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
NVTE_CHECK(output->flat_last_dim() == cols * 2, NVTE_CHECK(output->flat_last_dim() == cols * 2,
"Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(gated_input.data.shape == output->data.shape, NVTE_CHECK(gated_input.shape() == output->shape(),
"Gated input and output shapes must match. Input shape: ", gated_input.data.shape, "Gated input and output shapes must match. Input shape: ", gated_input.shape(),
", output shape: ", output->data.shape, "."); ", output shape: ", output->shape(), ".");
switch (output->scaling_mode) { switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: { case NVTE_DELAYED_TENSOR_SCALING: {
......
...@@ -227,8 +227,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) ...@@ -227,8 +227,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
bool use_colwise_scaling = input.has_columnwise_data(); bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream); checkCuDriverContext(stream);
const auto &input_shape = input.data.shape; NVTE_CHECK(input.dim() >= 2, "Input must have at least 2 dimensions.");
NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions.");
if (use_rowwise_scaling) { if (use_rowwise_scaling) {
NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data."); NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data.");
...@@ -241,7 +240,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) ...@@ -241,7 +240,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
} }
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
// TODO: Make more general // TODO: Make more general
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment