Unverified Commit cb5013bd authored by Tim Moon, committed by GitHub

[PyTorch] Refactor C++ quantizer infrastructure (#1952)



* remove reciprocal op
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Refactor Quantizer::create_tensor function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix bug when constructing FP8 tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add quantize function to C++ quantizers
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Prototype function to coerce Python quantized tensors to match quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use quantizer class in tex.quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 current scaling support for activation backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable quantized GEMM output with FP8 current scaling
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add coerce_tensor functions for MXFP8 and DSv3
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Avoid quantizing empty tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use consistent shapes for FP8 transposes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* In attention impl, construct FP8 tensors with pre-initialized scale-invs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initialize MXFP8 scales to zero
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Store copy of quantizer when creating quantized tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure quantized tensors have private quantizer

Avoid problems with in-place ops after a quantizer's usage flags are changed externally.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename "coerce_tensor" to "convert_and_update_tensor"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make sure CUDA context is available when launching NVRTC kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose CUDA context creation function externally
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5a495a39
......@@ -837,10 +837,9 @@ class TestBasicOps:
pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
if quantization == "mxfp8" and quantized_grad_input:
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
if quantization not in (None, "fp8"):
if quantized_output or quantized_grad_input:
pytest.skip("Recipe does not support quantized GEMM output")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......
......@@ -8,6 +8,7 @@
transformer_engine::cuda::stream_priority_range*;
transformer_engine::cuda::current_device*;
transformer_engine::cuda_driver::get_symbol*;
transformer_engine::cuda_driver::ensure_context_exists*;
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*;
......
......@@ -44,6 +44,19 @@ void *get_symbol(const char *symbol, int cuda_version) {
return entry_point;
}
void ensure_context_exists() {
CUcontext context;
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
if (context == nullptr) {
// Add primary context to context stack
CUdevice device;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device());
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
}
}
} // namespace cuda_driver
} // namespace transformer_engine
......@@ -39,6 +39,14 @@ inline CUresult call(const char *symbol, ArgTs... args) {
return (*func)(args...);
}
/*! \brief Ensure that the calling thread has a CUDA context
*
* Each thread maintains a stack of CUDA contexts. If the calling
* thread has an empty stack, the primary context is added to the
* stack.
*/
void ensure_context_exists();
} // namespace cuda_driver
} // namespace transformer_engine
......
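The doc comment above describes the per-thread CUDA context stack that ensure_context_exists() works with. As a rough standalone illustration (not part of this commit; the check helper and function names are made up), the same retain-and-bind pattern against the raw CUDA driver API could look like the sketch below. Note that the sketch keeps the retained primary-context reference alive until teardown, while the helper added in this commit releases it immediately, presumably because the PyTorch CUDA runtime already holds a reference to the primary context.

#include <cstdio>
#include <cstdlib>
#include <cuda.h>

static void check_cu(CUresult status, const char *what) {
  if (status != CUDA_SUCCESS) {
    std::fprintf(stderr, "%s failed (CUresult=%d)\n", what, static_cast<int>(status));
    std::exit(EXIT_FAILURE);
  }
}

// Bind the device's primary context to this thread if its context stack is
// empty. Returns true if a reference was retained; the caller should drop it
// later with cuDevicePrimaryCtxRelease().
static bool bind_primary_context_if_needed(int device_ordinal) {
  CUcontext context = nullptr;
  check_cu(cuCtxGetCurrent(&context), "cuCtxGetCurrent");
  if (context != nullptr) {
    return false;  // A context is already current on this thread.
  }
  CUdevice device;
  check_cu(cuDeviceGet(&device, device_ordinal), "cuDeviceGet");
  check_cu(cuDevicePrimaryCtxRetain(&context, device), "cuDevicePrimaryCtxRetain");
  check_cu(cuCtxSetCurrent(context), "cuCtxSetCurrent");
  return true;
}

int main() {
  check_cu(cuInit(0), "cuInit");
  const bool retained = bind_primary_context_if_needed(0);
  // ... driver-API work (e.g. cuLaunchKernel) would go here ...
  if (retained) {
    CUdevice device;
    check_cu(cuDeviceGet(&device, 0), "cuDeviceGet");
    check_cu(cuDevicePrimaryCtxRelease(device), "cuDevicePrimaryCtxRelease");
  }
  return 0;
}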
......@@ -59,6 +59,7 @@ class Kernel {
template <typename... ArgTs>
void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
cuda_driver::ensure_context_exists();
void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
......
......@@ -12,7 +12,7 @@
namespace transformer_engine::pytorch {
std::vector<size_t> getTensorShape(at::Tensor t) {
std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
shape.push_back(s);
......
......@@ -98,9 +98,21 @@ class Quantizer {
virtual void set_quantization_params(TensorWrapper* tensor) const = 0;
virtual std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;
/*! @brief Construct a tensor with uninitialized data */
virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const = 0;
/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
*
* The PyTorch tensor's attributes are modified to match the
* quantizer's configuration.
*/
virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
py::object tensor) const = 0;
/*! @brief Convert to a quantized data format */
virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;
virtual ~Quantizer() = default;
......@@ -121,9 +133,17 @@ class NoneQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override {}
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Tensor data) const;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8Quantizer : public Quantizer {
......@@ -139,9 +159,19 @@ class Float8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose,
std::optional<at::Tensor> scale_inv) const;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8CurrentScalingQuantizer : public Quantizer {
......@@ -161,9 +191,13 @@ class Float8CurrentScalingQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8BlockQuantizer : public Quantizer {
......@@ -195,9 +229,13 @@ class Float8BlockQuantizer : public Quantizer {
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
......@@ -212,16 +250,20 @@ class MXFP8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(at::Tensor t);
std::vector<size_t> getTensorShape(const at::Tensor& t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
......
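Taken together, these declarations let extension code allocate and fill a quantized tensor without recipe-specific branching at the call site. A minimal sketch of a caller follows (a hypothetical helper, not part of this commit; it assumes the declarations above plus the usual extension headers and makeTransformerEngineTensor):

namespace transformer_engine::pytorch {

// Quantize a PyTorch tensor with whichever quantizer the Python caller passed.
py::object quantize_with(const at::Tensor &input, py::handle quantizer_py,
                         const std::optional<at::Tensor> &noop_flag = std::nullopt) {
  // Wrap the input without copying and build the matching C++ quantizer.
  auto input_tensor = input.contiguous();
  const TensorWrapper &input_cpp = makeTransformerEngineTensor(input_tensor);
  std::unique_ptr<Quantizer> quantizer_cpp = convert_quantizer(quantizer_py);

  // create_tensor allocates both the C++ wrapper and the Python-facing tensor.
  const std::vector<size_t> shape = getTensorShape(input_tensor);
  auto [out_cpp, out_py] = quantizer_cpp->create_tensor(shape, input_cpp.dtype());

  // Optional no-op flag, as in tex.quantize.
  std::optional<TensorWrapper> noop_cpp;
  if (noop_flag.has_value()) {
    noop_cpp = makeTransformerEngineTensor(*noop_flag);
  }

  // quantize() hides the recipe-specific work (amax computation, scale
  // update, blockwise scale layout) behind one virtual call.
  quantizer_cpp->quantize(input_cpp, out_cpp, noop_cpp);
  return out_py;
}

}  // namespace transformer_engine::pytorch

When the caller already has a quantized Python tensor to fill, convert_and_update_tensor() takes the place of create_tensor(), as the updated tex.quantize shown later in this diff does.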
......@@ -13,87 +13,74 @@ namespace transformer_engine::pytorch {
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
input_shape[input_shape.size() - 1] /= shape_divisor;
auto fake_tensor_type = input.scalar_type();
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
// for current scaling, we need to compute amax first and then quantize
// because cache cannot fit in the entire tensor to compute amax and quantize
// the quantizer should not need amax reduction, no process group needed here
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// activation function might change the input data range, we need to first call the activation function
// and then find the amax and scale of that and then do the quantization
// get a NoneQuantizer to calculate amax of activation output
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
NVTE_SCOPED_GIL_RELEASE({
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to the compute amax and find the amax of activated tensor
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
if (my_quantizer_cs->with_amax_reduction) {
NVTE_ERROR(
"per-tensor current scaling amax reduction is not supported in activation functions.");
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
// sanity check, since activation fusion is not supported for blockwise quantization yet
// need to raise an error here instead of silently going into act_func with wrong numerics
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
// Input tensor
auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
// Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
// Compute activation
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
} else {
// Compute activation in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp->quantize(temp_cpp, out_cpp);
}
return out;
return out_py;
}
template <void (*act_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input,
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();
auto grad_tensor = grad.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
auto fake_tensor_type = input.scalar_type();
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
// Grad output and input tensors
auto grad_output_tensor = grad_output.contiguous();
auto input_tensor = input.contiguous();
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape_te = input_cpp.shape();
const std::vector<size_t> input_shape(input_shape_te.data,
input_shape_te.data + input_shape_te.ndim);
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
// Compute activation backward
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation backward directly
NVTE_SCOPED_GIL_RELEASE({
act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
} else {
// Compute activation backward in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
}
return out;
return grad_input_py;
}
py::object gelu(const at::Tensor& input, py::handle quantizer) {
......
......@@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
auto max_tokens = shape[0];
auto fcd_size = 1;
for (int i = 1; i <= shape.size(); i++) {
for (size_t i = 1; i <= shape.size(); i++) {
fcd_size *= shape[i];
}
......@@ -103,8 +103,20 @@ std::vector<py::object> fused_attn_fwd(
auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
py::object o_python, s_python;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// Initialize FP8 tensor with scale-inverse
auto *O_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(O_quantizer.get());
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
}
auto o_shape_int64 = std::vector<int64_t>{o_shape.begin(), o_shape.end()};
// construct NVTE tensors
......@@ -284,8 +296,20 @@ std::vector<py::object> fused_attn_bwd(
py::object s_python, dp_python;
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
std::unique_ptr<Quantizer> dP_quantizer = convert_quantizer(dp_quantizer);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
auto *dP_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(dP_quantizer.get());
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
}
std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape());
......@@ -374,9 +398,22 @@ std::vector<py::object> fused_attn_bwd(
default:
NVTE_ERROR("QKV layout not supported!");
}
std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *fp8_quantizer = dynamic_cast<Float8Quantizer *>(dQKV_quantizer.get());
NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_dQ, py_dQ) =
fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt);
std::tie(te_dK, py_dK) =
fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt);
std::tie(te_dV, py_dV) =
fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt);
} else {
auto *none_quantizer = dynamic_cast<NoneQuantizer *>(dQKV_quantizer.get());
NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8");
std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
}
// construct NVTE tensors
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
......
......@@ -28,60 +28,6 @@ std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py,
std::unique_ptr<Quantizer> &quantizer_cpp, TensorWrapper &output,
TensorWrapper &noop_flag) {
// Check tensor dims
NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output),
"Input tensor (shape=", get_tensor_shape(input),
") and output tensor (shape=", get_tensor_shape(output), ") do not match");
if (input.numel() == 0) {
return;
}
// Recipe-specific configuration
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(noop_flag.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); });
// check if we need to do amax reudction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
// construct torch tesnor from NVTEBasicTensor without reallocating memory
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
// this config is used for cs scaling factor computation
// because compute scale is cannot be fused with quantize kernel
// so in nvte_quantize_v2 with current scaling, the quant config is not used again
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
output.set_amax(nullptr, DType::kFloat32, output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(quantizer_cpp.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
// Perform quantization
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
} // namespace
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
......@@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
const auto fake_dtype = input_cpp.dtype();
std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
} else {
output_py = output;
output_cpp = makeTransformerEngineTensor(output_py, quantizer);
std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output);
}
// Initialize no-op flag
TensorWrapper noop_flag_cpp;
std::optional<TensorWrapper> noop_flag_cpp;
if (noop_flag.has_value()) {
noop_flag_cpp = makeTransformerEngineTensor(*noop_flag);
}
// Perform quantization
quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp);
quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
return output_py;
}
......@@ -182,10 +127,8 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
});
} else {
// Quantize kernels individually
TensorWrapper dummy_noop_flag;
for (size_t i = 0; i < num_tensors; ++i) {
quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i],
dummy_noop_flag);
quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]);
}
}
}
......
......@@ -18,27 +18,35 @@ namespace pytorch {
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
init_extension();
const auto dim = input.dim();
NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose.");
if (input.dim() > 2) {
input = input.view({-1, input.size(dim - 1)});
// Tensor dimensions
const auto shape = getTensorShape(input);
std::vector<int64_t> transpose_shape_int64;
if (shape.size() > 0) {
transpose_shape_int64.push_back(shape.back());
for (size_t i = 0; i < shape.size() - 1; ++i) {
transpose_shape_int64.push_back(shape[i]);
}
}
const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1;
const size_t N = shape.size() > 0 ? shape.back() : 1;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
// Output tensor
at::Tensor out;
if (output.has_value()) {
out = *output;
} else {
out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
out = at::empty(transpose_shape_int64, opts);
}
if (M == 0 || N == 0) return out;
// Return immediately if tensor is empty
if (M == 0 || N == 0) {
return out;
}
// Compute transpose
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return out;
......
......@@ -22,8 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
......@@ -480,18 +479,11 @@ class BasicLinear(BasicOperation):
raise ValueError("Output tensor is quantized, but quantizer was not provided")
else:
output_quantizer = None
if isinstance(output_quantizer, MXFP8Quantizer):
raise RuntimeError(
"Attempting to generate MXFP8 output tensor, "
"but GEMM with MXFP8 output is not supported"
)
if isinstance(output_quantizer, Float8BlockQuantizer):
if output_quantizer is not None:
if not isinstance(output_quantizer, Float8Quantizer):
raise RuntimeError(
"Attempting to generate Float8BlockQuantized output tensor, "
"but GEMM with Float8BlockQuantized output is not supported"
"Attempting to generate quantized output tensor with unsupported quantizer"
)
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Check if accumulating into output tensor
......@@ -765,10 +757,11 @@ class BasicLinear(BasicOperation):
)
else:
grad_input_quantizer = None
if isinstance(grad_input_quantizer, MXFP8Quantizer):
if grad_input_quantizer is not None:
if not isinstance(grad_input_quantizer, Float8Quantizer):
raise RuntimeError(
"Attempting to generate MXFP8 grad input tensor, "
"but GEMM with MXFP8 output is not supported"
"Attempting to generate quantized grad input tensor "
"with unsupported quantizer"
)
# Check if accumulating into grad input tensor
......
......@@ -182,7 +182,7 @@ class UserbuffersForwardLinear(FusedOperation):
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if output_quantizer is not None:
raise ValueError("FP8 output is not supported")
raise ValueError("Quantized output is not supported")
else:
input_quantizer = None
weight_quantizer = None
......
......@@ -59,7 +59,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
......
......@@ -86,7 +86,7 @@ class Float8TensorBase(QuantizedTensorBase):
else:
instance = super().__new__(cls, *args, **kwargs)
instance._data = data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._scale_inv = fp8_scale_inv
instance._transpose = data_transpose
......
......@@ -83,7 +83,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
......
......@@ -521,7 +521,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
dst._rowwise_data = src._rowwise_data
dst._columnwise_data = src._columnwise_data
dst._quantizer = src._quantizer
dst._quantizer = src._quantizer.copy()
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
......
......@@ -108,10 +108,9 @@ class Float8Quantizer(Quantizer):
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
inner_dim = data.size(-1)
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
data_transpose = torch.empty(
inner_dim,
data.numel() // inner_dim,
transpose_shape,
dtype=torch.uint8,
device=device,
)
......@@ -230,7 +229,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_epsilon: float = 0.0,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.ones(1, dtype=torch.float32, device=device)
self.scale = torch.empty(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device)
self.dtype = fp8_dtype
self.with_amax_reduction = with_amax_reduction
......@@ -690,7 +689,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
# Float8Tensor attributes
self._data = tensor._data
self._quantizer = tensor._quantizer
self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype
self._scale_inv = tensor._scale_inv
self._transpose = tensor._transpose
......
......@@ -433,7 +433,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor)
self._rowwise_data = tensor._rowwise_data
self._columnwise_data = tensor._columnwise_data
self._quantizer = tensor._quantizer
self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype
self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv
......