Unverified Commit feda5b55 authored by Jan Bielak, committed by GitHub

Fuse amax computation into activation kernel (#2004)



* Compute amax in activation kernels when the amax output pointer is provided, even for non-FP8 outputs
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 9f13fe2fefc58cae93bc467d87d01ecf792a0381)

* Initialize metatensor values
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fuse computation of amax into the activation kernel for fp8 current scaling
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 2b54327ac9c931a5340983a79e99de5caa0399dd)
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Zero out amax in `create_hp_tensor_with_amax` instead of relying on `Float8CurrentScalingQuantizer.__init__` to zero-initialize it
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent f858dc35
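The heart of the change, in one picture: instead of running the activation kernel and then a separate `nvte_compute_amax` pass over its output, the activation kernel now folds an amax reduction into the same pass whenever an amax pointer is attached to the output tensor. Below is a minimal CUDA sketch of that pattern; it is illustrative only, not the actual Transformer Engine kernel (`tanhf` stands in for the real activation, and the block size is assumed to be a multiple of 32):

```cpp
#include <cuda_runtime.h>
#include <math.h>

// Sketch of the fused amax pattern: each thread keeps a running |max| while
// writing the activation output, then one atomicMax per warp folds it into
// the global amax. atomicMax on the int bit pattern is valid here because
// both values are non-negative floats, whose IEEE-754 bit patterns order
// the same way as signed integers.
__global__ void act_with_amax(const float* in, float* out, float* amax, size_t n) {
  const bool requires_amax = (amax != nullptr);
  float local_max = 0.0f;
  for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < n;
       i += (size_t)gridDim.x * blockDim.x) {
    const float val = tanhf(in[i]);  // stand-in for the real activation
    if (requires_amax) local_max = fmaxf(fabsf(val), local_max);
    out[i] = val;
  }
  if (requires_amax) {
    // Warp-level max reduction, then one atomic per warp. The caller must
    // zero-initialize amax for the atomicMax to be meaningful.
    for (int offset = 16; offset > 0; offset >>= 1)
      local_max = fmaxf(local_max, __shfl_down_sync(0xffffffffu, local_max, offset));
    if ((threadIdx.x & 31) == 0)
      atomicMax(reinterpret_cast<int*>(amax), __float_as_int(local_max));
  }
}
```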
@@ -197,7 +197,8 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
     }
   } else {
     NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
-    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
+    // Note: amax is supported for non-FP8 output as it can be fused into the computation
+    // and later used for quantization with no need to compute it separately
     NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
     NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
                "Scale_inv is not supported for non-FP8 input ", name);
...
@@ -183,6 +183,7 @@ __launch_bounds__(unary_kernel_threads) __global__
     VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
     ComputeType max = 0;
     ComputeType s = 1;
+    const bool requires_amax = (amax != nullptr);
     if constexpr (is_fp8<OutputType>::value) {
       if (scale != nullptr) s = *scale;
     }
@@ -196,20 +197,20 @@ __launch_bounds__(unary_kernel_threads) __global__
      for (int i = 0; i < nvec; ++i) {
        const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
        ComputeType temp = OP(val, p);
-       if constexpr (is_fp8<OutputType>::value) {
+       if (requires_amax) {
          __builtin_assume(max >= 0);
          max = fmaxf(fabsf(temp), max);
+       }
+       if constexpr (is_fp8<OutputType>::value) {
          temp = temp * s;
        }
        storer.separate()[i] = static_cast<OutputType>(temp);
      }
      storer.store(tid, N);
    }
-   if constexpr (is_fp8<OutputType>::value) {
    // Reduce amax over block
-   if (amax != nullptr) {
+   if (requires_amax) {
      max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
      if (threadIdx.x == 0) {
        static_assert(std::is_same<ComputeType, float>::value);
@@ -217,6 +218,7 @@ __launch_bounds__(unary_kernel_threads) __global__
      }
    }
+   if constexpr (is_fp8<OutputType>::value) {
    // Update scale-inverse
    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
      reciprocal<ComputeType>(scale_inv, s);
@@ -236,6 +238,7 @@ __launch_bounds__(unary_kernel_threads) __global__
     VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
     ComputeType max = 0;
     ComputeType s = 1;
+    const bool requires_amax = (amax != nullptr);
     if constexpr (is_fp8<OutputType>::value) {
       if (scale != nullptr) s = *scale;
     }
@@ -251,10 +254,11 @@ __launch_bounds__(unary_kernel_threads) __global__
        const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
        const ComputeType g = static_cast<ComputeType>(grad_loader.separate()[i]);
        ComputeType temp = OP(val, p) * g;
-       if constexpr (is_fp8<OutputType>::value) {
+       if (requires_amax) {
          __builtin_assume(max >= 0);
          max = fmaxf(fabsf(temp), max);
+       }
+       if constexpr (is_fp8<OutputType>::value) {
          temp = temp * s;
        }
@@ -262,9 +266,9 @@ __launch_bounds__(unary_kernel_threads) __global__
      }
      storer.store(tid, N);
    }
-   if constexpr (is_fp8<OutputType>::value) {
    // Reduce amax over block
-   if (amax != nullptr) {
+   if (requires_amax) {
      max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
      if (threadIdx.x == 0) {
        static_assert(std::is_same<ComputeType, float>::value);
@@ -272,6 +276,7 @@ __launch_bounds__(unary_kernel_threads) __global__
      }
    }
+   if constexpr (is_fp8<OutputType>::value) {
    // Update scale-inverse
    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
      reciprocal<ComputeType>(scale_inv, s);
@@ -406,6 +411,7 @@ __launch_bounds__(unary_kernel_threads) __global__
     const size_t M = num_aligned_elements * m;
     ComputeType max = 0;
     ComputeType s = 1;
+    const bool requires_amax = (amax != nullptr);
     if constexpr (is_fp8<OutputType>::value) {
       if (scale != nullptr) s = *scale;
     }
@@ -425,18 +431,20 @@ __launch_bounds__(unary_kernel_threads) __global__
        const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
        const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
        ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
-       if constexpr (is_fp8<OutputType>::value) {
+       if (requires_amax) {
          __builtin_assume(max >= 0);
          max = fmaxf(fabsf(temp), max);
+       }
+       if constexpr (is_fp8<OutputType>::value) {
          temp = temp * s;
        }
        storer.separate()[i] = static_cast<OutputType>(static_cast<ComputeType>(temp));
      }
      storer.store(id_x, n);
    }
-   if constexpr (is_fp8<OutputType>::value) {
    // Reduce amax over block
-   if (amax != nullptr) {
+   if (requires_amax) {
      max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
      if (threadIdx.x == 0) {
        static_assert(std::is_same<ComputeType, float>::value);
@@ -444,6 +452,7 @@ __launch_bounds__(unary_kernel_threads) __global__
      }
    }
+   if constexpr (is_fp8<OutputType>::value) {
    // Update scale-inverse
    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
      reciprocal<ComputeType>(scale_inv, s);
@@ -497,6 +506,7 @@ __launch_bounds__(unary_kernel_threads) __global__
     const size_t M = num_aligned_elements * m;
     ComputeType max = 0;
     ComputeType s = 1;
+    const bool requires_amax = (amax != nullptr);
     if constexpr (is_fp8<OutputType>::value) {
       if (scale != nullptr) s = *scale;
     }
@@ -524,11 +534,13 @@ __launch_bounds__(unary_kernel_threads) __global__
        ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
        ComputeType after_dgate = grad_val * Activation(gelu_in, p);
-       if constexpr (is_fp8<OutputType>::value) {
+       if (requires_amax) {
          __builtin_assume(max >= 0);
          max = fmaxf(fabsf(after_dgelu), max);
-         after_dgelu = after_dgelu * s;
          max = fmaxf(fabsf(after_dgate), max);
+       }
+       if constexpr (is_fp8<OutputType>::value) {
+         after_dgelu = after_dgelu * s;
          after_dgate = after_dgate * s;
        }
@@ -538,9 +550,9 @@ __launch_bounds__(unary_kernel_threads) __global__
      storer0.store(id_x, n);
      storer1.store(id_x, n);
    }
-   if constexpr (is_fp8<OutputType>::value) {
    // Reduce amax over block
-   if (amax != nullptr) {
+   if (requires_amax) {
      max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
      if (threadIdx.x == 0) {
        static_assert(std::is_same<ComputeType, float>::value);
@@ -548,6 +560,7 @@ __launch_bounds__(unary_kernel_threads) __global__
      }
    }
+   if constexpr (is_fp8<OutputType>::value) {
    // Update scale-inverse
    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
      reciprocal<ComputeType>(scale_inv, s);
...
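The same pattern recurs across all four kernels above: the amax accumulation moves from the compile-time `if constexpr (is_fp8<OutputType>::value)` guard to a runtime `if (requires_amax)` check, so any output type can request a fused amax, while the multiply by the scale stays compile-time gated, since only FP8 outputs carry a scale. A reduced sketch of that split (the local `is_fp8` trait stands in for Transformer Engine's trait of the same name; not the actual kernel code):

```cpp
#include <cmath>

// Stand-in for Transformer Engine's is_fp8 trait; FP8 output types would
// specialize this to true.
template <typename T>
struct is_fp8 { static constexpr bool value = false; };

template <typename OutputType>
__device__ void scale_and_store(float temp, float s, float& max,
                                bool requires_amax, OutputType* dst) {
  if (requires_amax) {  // runtime gate: any output type may request amax
    max = fmaxf(fabsf(temp), max);
  }
  if constexpr (is_fp8<OutputType>::value) {  // compile-time gate: only FP8
    temp = temp * s;                          // outputs carry a scale
  }
  *dst = static_cast<OutputType>(temp);
}
```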
@@ -194,10 +194,26 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                      DType dtype) const override;

+  /*! @brief Construct a high-precision tensor giving it this quantizer's amax
+   *
+   *  Note: this member function also zeros out the amax, as it is meant to be used in
+   *  conjunction with a kernel computing the amax, which might expect the amax to be
+   *  initialized to zero.
+   */
+  std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape,
+                                                                  DType dtype);
+
   std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

   void quantize(const TensorWrapper& input, TensorWrapper& out,
                 const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

+  /*! @brief Convert to a quantized data format, avoiding amax computation */
+  void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
+                          const std::optional<TensorWrapper>& noop_flag = std::nullopt);
+
+ private:
+  void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
+                     const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
 };

 class Float8BlockQuantizer : public Quantizer {
...
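The two new member functions are designed to be used as a pair, which is exactly what the pybind helpers in the next file do. A condensed sketch of the intended call sequence (the wrapper function name is hypothetical and `nvte_gelu` is one possible activation; the other names come from this diff, and Transformer Engine's PyTorch C++ extension headers are assumed):

```cpp
#include <vector>

// Hypothetical wrapper illustrating the intended flow with
// Float8CurrentScalingQuantizer, TensorWrapper, and DType from the
// Transformer Engine PyTorch extension.
void activation_current_scaling_path(Float8CurrentScalingQuantizer& q,
                                     const TensorWrapper& input, TensorWrapper& fp8_out,
                                     const std::vector<size_t>& shape, DType dtype,
                                     cudaStream_t stream) {
  // 1. High-precision output tensor carrying the quantizer's (zeroed) amax.
  auto [hp_out, hp_out_py] = q.create_hp_tensor_with_amax(shape, dtype);
  // 2. The activation kernel writes the output and fuses the amax reduction.
  nvte_gelu(input.data(), hp_out.data(), stream);
  // 3. Quantize with the already-computed amax; no separate nvte_compute_amax pass.
  q.quantize_with_amax(hp_out, fp8_out);
}
```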
@@ -32,8 +32,17 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
     // Compute activation directly
     NVTE_SCOPED_GIL_RELEASE(
         { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    // Compute activation in high-precision fused together with amax, then quantize.
+    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype);
+    NVTE_SCOPED_GIL_RELEASE(
+        { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
+    quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp);
   } else {
     // Compute activation in high-precision, then quantize
     auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
     NVTE_SCOPED_GIL_RELEASE(
         { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
@@ -70,6 +79,15 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
       dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
                 at::cuda::getCurrentCUDAStream());
     });
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    // Compute activation backward in high-precision fused together with amax, then quantize.
+    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype);
+    NVTE_SCOPED_GIL_RELEASE({
+      dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
+                at::cuda::getCurrentCUDAStream());
+    });
+    quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp);
   } else {
     // Compute activation backward in high-precision, then quantize
     auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
...
@@ -397,6 +397,15 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   return {std::move(out_cpp), std::move(out_py)};
 }

+std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_hp_tensor_with_amax(
+    const std::vector<size_t>& shape, DType dtype) {
+  amax.zero_();
+  auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
+  out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
+                   getTensorShape(amax));
+  return {std::move(out_cpp), std::move(out_py)};
+}
+
 std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_update_tensor(
     py::object tensor) const {
   NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()),
@@ -489,8 +498,9 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_
   return {std::move(out_cpp), std::move(tensor)};
 }

-void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
-                                             const std::optional<TensorWrapper>& noop_flag) {
+void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
+                                                  const std::optional<TensorWrapper>& noop_flag,
+                                                  bool compute_amax) {
   auto stream = at::cuda::getCurrentCUDAStream();

   // Nothing to be done if input is empty
@@ -507,7 +517,9 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorW
   quant_config.set_amax_epsilon(amax_epsilon);

   // Compute amax
+  if (compute_amax) {
     NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
+  }

   // Perform amax reduction if needed
   if (with_amax_reduction) {
@@ -526,6 +538,19 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorW
   NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
 }

+void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
+                                             const std::optional<TensorWrapper>& noop_flag) {
+  this->quantize_impl(input, out, noop_flag, true);
+}
+
+void Float8CurrentScalingQuantizer::quantize_with_amax(
+    TensorWrapper& input, TensorWrapper& out, const std::optional<TensorWrapper>& noop_flag) {
+  NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(),
+             "Input does not use the appropriate amax tensor");
+  input.set_amax(nullptr, DType::kFloat32, input.defaultShape);
+  this->quantize_impl(input, out, noop_flag, false);
+}
+
 Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
   this->dtype = quantizer.attr("dtype").cast<DType>();
   this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
...
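Why `create_hp_tensor_with_amax` zeroes the amax before handing it to the kernel: the fused kernels fold their block maxima into the existing value with an atomic max, so a stale amax from a previous iteration can only persist or grow, never shrink. A minimal host-side illustration of the failure mode that zeroing prevents (plain arithmetic, under the atomic-max assumption sketched earlier):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Stale amax left over from a previous step vs. the current tensor's
  // true amax. A max against the stale value keeps 3.0f, so the derived
  // FP8 scale would be ~6x too small for the current data.
  float stale_amax = 3.0f;
  float true_amax = 0.5f;
  printf("without zeroing: %f\n", fmaxf(stale_amax, true_amax));  // 3.0
  printf("with zeroing:    %f\n", fmaxf(0.0f, true_amax));        // 0.5
  return 0;
}
```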