Unverified commit cb5013bd, authored by Tim Moon and committed by GitHub
Browse files

[PyTorch] Refactor C++ quantizer infrastructure (#1952)



* remove reciprocal op
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Refactor Quantizer::create_tensor function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix bug when constructing FP8 tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add quantize function to C++ quantizers
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Prototype function to coerce Python quantized tensors to match quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use quantizer class in tex.quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 current scaling support for activation backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable quantized GEMM output with FP8 current scaling
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add coerce_tensor functions for MXFP8 and DSv3
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Avoid quantizing empty tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use consistent shapes for FP8 transposes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* In attention impl, construct FP8 tensors with pre-initialized scale-invs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initialize MXFP8 scales to zero
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Store copy of quantizer when creating quantized tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure quantized tensors have private quantizer

Avoid problems with in-place ops after quantizer usages are changed externally.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename "coerce_tensor" to "convert_and_update_tensor"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make sure CUDA context is available when launching NVRTC kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose CUDA context creation function externally
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5a495a39
......@@ -837,10 +837,9 @@ class TestBasicOps:
pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
if quantization == "mxfp8" and quantized_grad_input:
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
if quantization not in (None, "fp8"):
if quantized_output or quantized_grad_input:
pytest.skip("Recipe does not support quantized GEMM output")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......
......@@ -8,6 +8,7 @@
transformer_engine::cuda::stream_priority_range*;
transformer_engine::cuda::current_device*;
transformer_engine::cuda_driver::get_symbol*;
transformer_engine::cuda_driver::ensure_context_exists*;
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*;
......
......@@ -44,6 +44,19 @@ void *get_symbol(const char *symbol, int cuda_version) {
return entry_point;
}
/*! \brief Make sure a CUDA context is current on the calling thread
 *
 * Queries the thread's current context through the driver API. If
 * one is already set, nothing happens. Otherwise the primary context
 * of the device reported by cuda::current_device() is retained, made
 * current, and then released again.
 */
void ensure_context_exists() {
  CUcontext context;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
  if (context != nullptr) {
    // Thread already has a current context
    return;
  }

  // No current context: put the device's primary context on the
  // thread's context stack
  CUdevice device;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device());
  NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
  NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
  // NOTE(review): the retain is dropped right after the context is made
  // current; this presumably relies on another owner (e.g. the CUDA
  // runtime) keeping the primary context alive -- confirm.
  NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
}
} // namespace cuda_driver
} // namespace transformer_engine
......@@ -39,6 +39,14 @@ inline CUresult call(const char *symbol, ArgTs... args) {
return (*func)(args...);
}
/*! \brief Ensure that the calling thread has a CUDA context
*
* Each thread maintains a stack of CUDA contexts. If the calling
* thread has an empty stack, the primary context is added to the
* stack.
*/
void ensure_context_exists();
} // namespace cuda_driver
} // namespace transformer_engine
......
......@@ -59,6 +59,7 @@ class Kernel {
template <typename... ArgTs>
void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
cuda_driver::ensure_context_exists();
void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
......
......@@ -12,7 +12,7 @@
namespace transformer_engine::pytorch {
std::vector<size_t> getTensorShape(at::Tensor t) {
std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
shape.push_back(s);
......
......@@ -98,9 +98,21 @@ class Quantizer {
virtual void set_quantization_params(TensorWrapper* tensor) const = 0;
virtual std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;
/*! @brief Construct a tensor with uninitialized data */
virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const = 0;
/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
*
* The PyTorch tensor's attributes are modified to match the
* quantizer's configuration.
*/
virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
py::object tensor) const = 0;
/*! @brief Convert to a quantized data format */
virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;
virtual ~Quantizer() = default;
......@@ -121,9 +133,17 @@ class NoneQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override {}
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Tensor data) const;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8Quantizer : public Quantizer {
......@@ -139,9 +159,19 @@ class Float8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose,
std::optional<at::Tensor> scale_inv) const;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8CurrentScalingQuantizer : public Quantizer {
......@@ -161,9 +191,13 @@ class Float8CurrentScalingQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8BlockQuantizer : public Quantizer {
......@@ -195,9 +229,13 @@ class Float8BlockQuantizer : public Quantizer {
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
......@@ -212,16 +250,20 @@ class MXFP8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(at::Tensor t);
std::vector<size_t> getTensorShape(const at::Tensor& t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
......
......@@ -13,87 +13,74 @@ namespace transformer_engine::pytorch {
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
input_shape[input_shape.size() - 1] /= shape_divisor;
auto fake_tensor_type = input.scalar_type();
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
// for current scaling, we need to compute amax first and then quantize
// because cache cannot fit in the entire tensor to compute amax and quantize
// the quantizer should not need amax reduction, no process group needed here
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// activation function might change the input data range, we need to first call the activation function
// and then find the amax and scale of that and then do the quantization
// get a NoneQuantizer to calculate amax of activation output
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
NVTE_SCOPED_GIL_RELEASE({
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to the compute amax and find the amax of activated tensor
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
if (my_quantizer_cs->with_amax_reduction) {
NVTE_ERROR(
"per-tensor current scaling amax reduction is not supported in activation functions.");
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
// sanity check, since activation fusion is not supported for blockwise quantization yet
// need to raise an error here instead of silently going into act_func with wrong numerics
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
// Input tensor
auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
// Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
// Compute activation
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
} else {
// Compute activation in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp->quantize(temp_cpp, out_cpp);
}
return out;
return out_py;
}
template <void (*act_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input,
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();
auto grad_tensor = grad.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
auto fake_tensor_type = input.scalar_type();
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
// Grad output and input tensors
auto grad_output_tensor = grad_output.contiguous();
auto input_tensor = input.contiguous();
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape_te = input_cpp.shape();
const std::vector<size_t> input_shape(input_shape_te.data,
input_shape_te.data + input_shape_te.ndim);
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
// Compute activation backward
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation backward directly
NVTE_SCOPED_GIL_RELEASE({
act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
} else {
// Compute activation backward in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
}
return out;
return grad_input_py;
}
py::object gelu(const at::Tensor& input, py::handle quantizer) {
......
......@@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
auto max_tokens = shape[0];
auto fcd_size = 1;
for (int i = 1; i <= shape.size(); i++) {
for (size_t i = 1; i <= shape.size(); i++) {
fcd_size *= shape[i];
}
......@@ -103,8 +103,20 @@ std::vector<py::object> fused_attn_fwd(
auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
py::object o_python, s_python;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// Initialize FP8 tensor with scale-inverse
auto *O_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(O_quantizer.get());
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
}
auto o_shape_int64 = std::vector<int64_t>{o_shape.begin(), o_shape.end()};
// construct NVTE tensors
......@@ -284,8 +296,20 @@ std::vector<py::object> fused_attn_bwd(
py::object s_python, dp_python;
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
std::unique_ptr<Quantizer> dP_quantizer = convert_quantizer(dp_quantizer);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
auto *dP_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(dP_quantizer.get());
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
}
std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape());
......@@ -374,9 +398,22 @@ std::vector<py::object> fused_attn_bwd(
default:
NVTE_ERROR("QKV layout not supported!");
}
std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *fp8_quantizer = dynamic_cast<Float8Quantizer *>(dQKV_quantizer.get());
NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_dQ, py_dQ) =
fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt);
std::tie(te_dK, py_dK) =
fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt);
std::tie(te_dV, py_dV) =
fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt);
} else {
auto *none_quantizer = dynamic_cast<NoneQuantizer *>(dQKV_quantizer.get());
NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8");
std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
}
// construct NVTE tensors
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
......
......@@ -28,60 +28,6 @@ std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
/*! @brief Quantize a tensor with recipe-specific configuration
 *
 * Verifies that the input and output shapes match and returns without
 * doing any work if the input has no elements. The quantization config
 * is then populated depending on the Python quantizer's type:
 *  - FP8 current scaling: compute the amax, optionally max-all-reduce
 *    it over the quantizer's amax reduction group, compute the scale
 *    from the amax, and clear the amax pointer in the output.
 *  - FP8 blockwise: copy the power-of-2-scale and amax-epsilon options
 *    and select the compact scale format when all-gather usage is set.
 * Finally nvte_quantize_v2 is launched on the current CUDA stream.
 *
 * \param[in]     input          Tensor to quantize
 * \param[in]     quantizer_py   Python quantizer, used to detect the recipe
 * \param[in]     quantizer_cpp  C++ quantizer matching quantizer_py
 * \param[in,out] output         Quantized tensor
 * \param[in]     noop_flag      Tensor passed to the config as the no-op flag
 */
void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py,
std::unique_ptr<Quantizer> &quantizer_cpp, TensorWrapper &output,
TensorWrapper &noop_flag) {
// Check tensor dims
NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output),
"Input tensor (shape=", get_tensor_shape(input),
") and output tensor (shape=", get_tensor_shape(output), ") do not match");
// Nothing to quantize for an empty tensor
if (input.numel() == 0) {
return;
}
// Recipe-specific configuration
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(noop_flag.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); });
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
// reuse the quantizer's amax torch tensor so that no memory is reallocated
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
// this config is used for the current-scaling scaling factor computation
// because computing the scale cannot be fused with the quantize kernel,
// so in nvte_quantize_v2 with current scaling, the quant config is not used again
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
output.set_amax(nullptr, DType::kFloat32, output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(quantizer_cpp.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
// Perform quantization
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
} // namespace
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
......@@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
const auto fake_dtype = input_cpp.dtype();
std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
} else {
output_py = output;
output_cpp = makeTransformerEngineTensor(output_py, quantizer);
std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output);
}
// Initialize no-op flag
TensorWrapper noop_flag_cpp;
std::optional<TensorWrapper> noop_flag_cpp;
if (noop_flag.has_value()) {
noop_flag_cpp = makeTransformerEngineTensor(*noop_flag);
}
// Perform quantization
quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp);
quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
return output_py;
}
......@@ -182,10 +127,8 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
});
} else {
// Quantize kernels individually
TensorWrapper dummy_noop_flag;
for (size_t i = 0; i < num_tensors; ++i) {
quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i],
dummy_noop_flag);
quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]);
}
}
}
......
......@@ -18,27 +18,35 @@ namespace pytorch {
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
init_extension();
const auto dim = input.dim();
NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose.");
if (input.dim() > 2) {
input = input.view({-1, input.size(dim - 1)});
// Tensor dimensions
const auto shape = getTensorShape(input);
std::vector<int64_t> transpose_shape_int64;
if (shape.size() > 0) {
transpose_shape_int64.push_back(shape.back());
for (size_t i = 0; i < shape.size() - 1; ++i) {
transpose_shape_int64.push_back(shape[i]);
}
}
const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1;
const size_t N = shape.size() > 0 ? shape.back() : 1;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
// Output tensor
at::Tensor out;
if (output.has_value()) {
out = *output;
} else {
out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
out = at::empty(transpose_shape_int64, opts);
}
if (M == 0 || N == 0) return out;
// Return immediately if tensor is empty
if (M == 0 || N == 0) {
return out;
}
// Compute transpose
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return out;
......
......@@ -12,6 +12,27 @@
namespace transformer_engine::pytorch {
namespace {
/*! @brief Shape of a tensor after a flattened 2D transpose
 *
 * The tensor is viewed as a matrix whose rows are formed by
 * flattening every dimension except the last. Transposing that
 * matrix moves the last dimension to the front and keeps the
 * remaining dimensions in their original order, e.g.
 * {a, b, c} -> {c, a, b}. An empty shape maps to an empty shape.
 *
 * \tparam T Element type of the returned shape
 * \tparam S Element type of the input shape
 */
template <typename T = size_t, typename S = T>
std::vector<T> make_transpose_shape(const std::vector<S>& shape) {
  std::vector<T> transposed;
  if (shape.empty()) {
    return transposed;
  }
  transposed.reserve(shape.size());
  // Last dim comes first...
  transposed.push_back(shape.back());
  // ...followed by the leading dims in their original order
  transposed.insert(transposed.end(), shape.begin(), shape.end() - 1);
  return transposed;
}
} // namespace
constexpr size_t MXFP8_BLOCK_SIZE = 32;
Quantizer::Quantizer(const py::handle& quantizer) {
......@@ -37,24 +58,36 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti
this->dtype = type;
}
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
at::TensorOptions opts;
opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA);
std::vector<int64_t> torch_shape;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
}
at::Tensor ret;
if (rowwise_data.has_value()) {
ret = std::move(*rowwise_data);
} else {
ret = at::empty(torch_shape, opts);
}
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA);
return create_tensor(shape, dtype, at::empty(shape_int64, opts));
}
TensorWrapper tensor;
tensor.set_rowwise_data(ret.data_ptr(), dtype, shape);
return {std::move(tensor), py::cast(ret)};
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype,
at::Tensor data) const {
TensorWrapper out_cpp;
out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape);
set_quantization_params(&out_cpp);
return {std::move(out_cpp), py::cast(data)};
}
/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
 *
 * The Python object is cast to a PyTorch tensor, and the C++ tensor
 * is set up to point at its buffer with the matching dtype and shape.
 * The original Python object is returned alongside the C++ tensor.
 */
std::pair<TensorWrapper, py::object> NoneQuantizer::convert_and_update_tensor(
    py::object tensor) const {
  const at::Tensor torch_tensor = tensor.cast<at::Tensor>();
  const auto te_dtype = GetTransformerEngineDType(torch_tensor.scalar_type());
  const auto te_shape = getTensorShape(torch_tensor);

  TensorWrapper wrapper;
  wrapper.set_rowwise_data(torch_tensor.data_ptr(), te_dtype, te_shape);
  this->set_quantization_params(&wrapper);
  return {std::move(wrapper), std::move(tensor)};
}
/*! @brief Unsupported operation for the no-op quantizer
 *
 * Always raises an error through NVTE_ERROR; none of the arguments
 * are read.
 */
void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
NVTE_ERROR("NoneQuantizer does not support quantization");
}
void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
......@@ -76,68 +109,180 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
}
std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
const std::vector<size_t>& shape, DType dtype) const {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
at::Tensor scale_inv = at::empty(std::vector<int64_t>{1}, opts);
return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv));
}
std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv) const {
using namespace pybind11::literals;
std::vector<int64_t> rowwise_torch_shape;
std::vector<int64_t> columnwise_torch_shape;
if (!shape.empty()) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < shape.size() - 1) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
at::TensorOptions opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(rowwise_torch_shape, opts);
// Initialize data tensor
const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
if (with_data && !data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data = at::empty(shape_int64, opts);
} else if (!with_data && data) {
data.reset();
}
py::object data_py = with_data ? py::cast(*data) : py::none();
// Initialize transpose tensor
const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (with_transpose && !transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose = at::empty(transpose_shape, opts);
} else if (!with_transpose && transpose) {
transpose.reset();
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
opts = opts.dtype(torch::kFloat32);
// TODO: Replace with an empty tensor.
at::Tensor scale_inv = at::reciprocal(scale);
py::object ret;
// Construct Python FP8 tensor
py::object out_py;
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
} else {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
"data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
}
TensorWrapper tensor(this->get_scaling_mode());
if (rowwise_usage) {
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
// Construct C++ FP8 tensor
TensorWrapper out_cpp(this->get_scaling_mode());
if (with_data) {
out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (create_transpose) {
std::vector<size_t> transposed_shape;
for (auto s : columnwise_torch_shape) {
transposed_shape.emplace_back(static_cast<size_t>(s));
if (with_transpose) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
// Adapts an existing Python Float8Tensor to this quantizer's usages and wraps
// its buffers in a C++ TensorWrapper.
//
// Buffers that are not needed are removed from the Python tensor, buffers that
// are needed but absent are allocated as uninitialized uint8 CUDA tensors, and
// the tensor's `_transpose_invalid` / `_fp8_dtype` attributes are overwritten.
// Returns the C++ wrapper together with the (mutated) Python tensor. Raises if
// the object is not a Float8Tensor, if the usages require no buffer at all, if
// the tensor holds no buffer, or if its data and transpose shapes disagree.
std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
    py::object tensor) const {
  NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor.");

  // Buffers that the configured usages call for
  const bool want_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
  const bool want_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
  NVTE_CHECK(want_data || want_transpose, "Invalid usages for Float8Quantizer.");

  // Buffers that the Python tensor currently holds
  auto py_data = tensor.attr("_data");
  auto py_transpose = tensor.attr("_transpose");
  const bool had_data = !py_data.is_none();
  const bool had_transpose = !py_transpose.is_none();
  NVTE_CHECK(had_data || had_transpose, "Float8Tensor has no data.");
  std::optional<at::Tensor> data_buf;
  std::optional<at::Tensor> transpose_buf;
  if (had_data) {
    data_buf = py_data.cast<at::Tensor>();
  }
  if (had_transpose) {
    transpose_buf = py_transpose.cast<at::Tensor>();
  }
  at::Tensor scale_inv_buf = tensor.attr("_scale_inv").cast<at::Tensor>();

  // Logical tensor shape. When a transpose is present the shape is derived
  // from it by moving its leading dimension to the end.
  std::vector<size_t> shape;
  if (!had_transpose) {
    // The check above guarantees that data is present in this branch
    shape = getTensorShape(*data_buf);
  } else {
    const auto transpose_shape = getTensorShape(*transpose_buf);
    const size_t ndim = transpose_shape.size();
    for (size_t i = 1; i < ndim; ++i) {
      shape.push_back(transpose_shape[i]);
    }
    if (ndim > 0) {
      shape.push_back(transpose_shape.front());
    }
    if (had_data) {
      auto expected_shape = getTensorShape(*data_buf);
      NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                 ") and transpose (shape=", transpose_shape, ") do not match");
    }
  }

  // Coerce data buffer: allocate it if wanted but missing, drop it if present
  // but unwanted
  if (want_data) {
    if (!had_data) {
      const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      data_buf = at::empty(shape_int64, opts);
      py_data = py::cast(data_buf);
      tensor.attr("_data") = py_data;
    }
  } else if (had_data) {
    data_buf.reset();
    py_data = py::none();
    tensor.attr("_data") = py_data;
  }

  // Coerce transpose buffer in the same way
  if (want_transpose) {
    if (!had_transpose) {
      const auto transpose_shape = make_transpose_shape<int64_t>(shape);
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      transpose_buf = at::empty(transpose_shape, opts);
      py_transpose = py::cast(transpose_buf);
      tensor.attr("_transpose") = py_transpose;
    }
  } else if (had_transpose) {
    transpose_buf.reset();
    py_transpose = py::none();
    tensor.attr("_transpose") = py_transpose;
  }
  tensor.attr("_transpose_invalid") = !want_transpose;

  // Remaining Python attributes
  tensor.attr("_fp8_dtype") = dtype;

  // C++ view over the surviving buffers; both directions share one scale-inv
  TensorWrapper out_cpp;
  const std::vector<size_t> scale_inv_shape{1};
  if (data_buf) {
    out_cpp.set_rowwise_data(data_buf->data_ptr(), this->dtype, shape);
    out_cpp.set_rowwise_scale_inv(scale_inv_buf.data_ptr(), DType::kFloat32, scale_inv_shape);
  }
  if (transpose_buf) {
    const auto transpose_shape = make_transpose_shape(shape);
    out_cpp.set_columnwise_data(transpose_buf->data_ptr(), this->dtype, transpose_shape);
    out_cpp.set_columnwise_scale_inv(scale_inv_buf.data_ptr(), DType::kFloat32, scale_inv_shape);
  }
  this->set_quantization_params(&out_cpp);

  return {std::move(out_cpp), std::move(tensor)};
}
void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer)
......@@ -187,71 +332,198 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
const std::vector<size_t>& shape, DType dtype) const {
using namespace pybind11::literals;
std::vector<int64_t> rowwise_torch_shape;
std::vector<int64_t> columnwise_torch_shape;
std::vector<int64_t> scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv
if (!shape.empty()) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < shape.size() - 1) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
at::TensorOptions opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(rowwise_torch_shape, opts);
}
// Initialize data tensor
at::Tensor data_tensor;
const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
if (with_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
data_tensor = at::empty(shape_int64, opts);
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
// Initialize transpose tensor
at::Tensor transpose_tensor;
const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (with_transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
transpose_tensor = at::empty(transpose_shape, opts);
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
// In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set.
at::Tensor scale_inv = at::reciprocal(scale);
// Initialize scale-inverse tensor
at::Tensor scale_inv_tensor;
{
const std::vector<int64_t> scale_inv_shape = {1};
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
scale_inv_tensor = at::empty(scale_inv_shape, opts);
}
py::object ret;
// Construct Python FP8 tensor
py::object out_py;
py::object data_py = with_data ? py::cast(data_tensor) : py::none();
py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
} else {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
"data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
}
TensorWrapper tensor(this->get_scaling_mode());
if (rowwise_usage) {
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
// Construct C++ FP8 tensor
TensorWrapper out_cpp(this->get_scaling_mode());
if (with_data) {
out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (create_transpose) {
std::vector<size_t> transposed_shape;
for (auto s : columnwise_torch_shape) {
transposed_shape.emplace_back(static_cast<size_t>(s));
if (with_transpose) {
const auto transpose_shape = make_transpose_shape(shape);
out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape);
out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
// Adapts an existing Python Float8Tensor to this quantizer's usages and wraps
// its buffers in a C++ TensorWrapper.
//
// Unneeded buffers are removed from the Python tensor, needed-but-absent ones
// are allocated as uninitialized uint8 CUDA tensors, and the tensor's
// `_transpose_invalid` / `_fp8_dtype` attributes are overwritten. Returns the
// C++ wrapper together with the (mutated) Python tensor. Raises if the object
// is not a Float8Tensor, if the usages require no buffer, if the tensor holds
// no buffer, or if its data and transpose shapes disagree.
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_update_tensor(
    py::object tensor) const {
  NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()),
             "Float8CurrentScalingQuantizer must output to Float8Tensor.");

  // Buffers that the configured usages call for
  const bool want_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
  const bool want_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
  NVTE_CHECK(want_data || want_transpose, "Invalid quantizer usages.");

  // Buffers that the Python tensor currently holds
  auto py_data = tensor.attr("_data");
  auto py_transpose = tensor.attr("_transpose");
  const bool had_data = !py_data.is_none();
  const bool had_transpose = !py_transpose.is_none();
  NVTE_CHECK(had_data || had_transpose, "Tensor has no data.");
  std::optional<at::Tensor> data_buf;
  std::optional<at::Tensor> transpose_buf;
  if (had_data) {
    data_buf = py_data.cast<at::Tensor>();
  }
  if (had_transpose) {
    transpose_buf = py_transpose.cast<at::Tensor>();
  }
  at::Tensor scale_inv_buf = tensor.attr("_scale_inv").cast<at::Tensor>();

  // Logical tensor shape. When a transpose is present the shape is derived
  // from it by moving its leading dimension to the end.
  std::vector<size_t> shape;
  if (!had_transpose) {
    // The check above guarantees that data is present in this branch
    shape = getTensorShape(*data_buf);
  } else {
    const auto transpose_shape = getTensorShape(*transpose_buf);
    const size_t ndim = transpose_shape.size();
    for (size_t i = 1; i < ndim; ++i) {
      shape.push_back(transpose_shape[i]);
    }
    if (ndim > 0) {
      shape.push_back(transpose_shape.front());
    }
    if (had_data) {
      auto expected_shape = getTensorShape(*data_buf);
      NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                 ") and transpose (shape=", transpose_shape, ") do not match");
    }
  }

  // Coerce data buffer in the Python tensor: allocate it if wanted but
  // missing, drop it if present but unwanted
  if (want_data) {
    if (!had_data) {
      const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      data_buf = at::empty(shape_int64, opts);
      py_data = py::cast(data_buf);
      tensor.attr("_data") = py_data;
    }
  } else if (had_data) {
    data_buf.reset();
    py_data = py::none();
    tensor.attr("_data") = py_data;
  }

  // Coerce transpose buffer in the same way
  if (want_transpose) {
    if (!had_transpose) {
      const auto transpose_shape = make_transpose_shape<int64_t>(shape);
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      transpose_buf = at::empty(transpose_shape, opts);
      py_transpose = py::cast(transpose_buf);
      tensor.attr("_transpose") = py_transpose;
    }
  } else if (had_transpose) {
    transpose_buf.reset();
    py_transpose = py::none();
    tensor.attr("_transpose") = py_transpose;
  }
  tensor.attr("_transpose_invalid") = !want_transpose;

  // Remaining Python attributes
  tensor.attr("_fp8_dtype") = dtype;

  // C++ view over the surviving buffers; both directions share one scale-inv
  TensorWrapper out_cpp;
  const std::vector<size_t> scale_inv_shape{1};
  if (data_buf) {
    out_cpp.set_rowwise_data(data_buf->data_ptr(), this->dtype, shape);
    out_cpp.set_rowwise_scale_inv(scale_inv_buf.data_ptr(), DType::kFloat32, scale_inv_shape);
  }
  if (transpose_buf) {
    const auto transpose_shape = make_transpose_shape(shape);
    out_cpp.set_columnwise_data(transpose_buf->data_ptr(), this->dtype, transpose_shape);
    out_cpp.set_columnwise_scale_inv(scale_inv_buf.data_ptr(), DType::kFloat32, scale_inv_shape);
  }
  this->set_quantization_params(&out_cpp);

  return {std::move(out_cpp), std::move(tensor)};
}
void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
auto stream = at::cuda::getCurrentCUDAStream();
// Nothing to be done if input is empty
if (input.numel() == 0) {
return;
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
// Quantization configs
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
// Compute amax
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
// Perform amax reduction if needed
if (with_amax_reduction) {
// allreduce amax tensor
c10d::AllreduceOptions opts;
opts.reduceOp = c10d::ReduceOp::MAX;
std::vector<at::Tensor> tensors = {amax};
NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); });
}
// Compute scaling factor
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); });
// Cast to FP8
out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates
NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
}
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
......@@ -280,7 +552,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
const std::vector<size_t>& shape, DType dtype) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
for (auto s : shape) {
......@@ -299,11 +571,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
: Float8BlockScaleTensorFormat::GEMM_READY);
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data_rowwise = std::move(*rowwise_data);
} else {
data_rowwise = at::empty(torch_shape, opts);
}
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
......@@ -373,6 +641,62 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
return {std::move(tensor), std::move(ret)};
}
// Wraps the buffers of an existing Python block-scaled FP8 tensor in a C++
// TensorWrapper.
//
// The Python tensor is not modified. Its row-wise / column-wise data must be
// present exactly when the corresponding quantizer usage is enabled; any
// mismatch raises an error. The FP8 dtype and the 1D/2D scaling mode are read
// from the tensor's `_fp8_dtype` and `_is_2D_scaled` attributes rather than
// from the quantizer. Returns the C++ wrapper together with the Python tensor.
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_tensor(
    py::object tensor) const {
  const DType tensor_dtype = tensor.attr("_fp8_dtype").cast<DType>();
  const bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();

  // Check the data matches quantizer usages. Each attribute is fetched from
  // Python once and reused for both the check and the buffer extraction.
  const py::object rowwise_data_py = tensor.attr("_rowwise_data");
  const bool has_rowwise_data = !rowwise_data_py.is_none();
  NVTE_CHECK(has_rowwise_data == rowwise_usage,
             "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=",
             has_rowwise_data, ", rowwise_usage=", rowwise_usage, ")");
  const py::object columnwise_data_py = tensor.attr("_columnwise_data");
  const bool has_columnwise_data = !columnwise_data_py.is_none();
  NVTE_CHECK(has_columnwise_data == columnwise_usage,
             "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=",
             has_columnwise_data, ", columnwise_usage=", columnwise_usage, ")");

  // Construct C++ tensor over the existing buffers
  TensorWrapper out_cpp(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
  if (rowwise_usage) {
    const at::Tensor data_rowwise = rowwise_data_py.cast<at::Tensor>();
    const at::Tensor scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
    const auto data_shape = getTensorShape(data_rowwise);
    const auto scale_inv_shape = getTensorShape(scale_inv_rowwise);
    out_cpp.set_rowwise_data(data_rowwise.data_ptr(), tensor_dtype, data_shape);
    out_cpp.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, scale_inv_shape);
  }
  if (columnwise_usage) {
    const at::Tensor data_colwise = columnwise_data_py.cast<at::Tensor>();
    const at::Tensor scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
    const auto data_shape = getTensorShape(data_colwise);
    const auto scale_inv_shape = getTensorShape(scale_inv_colwise);
    out_cpp.set_columnwise_data(data_colwise.data_ptr(), tensor_dtype, data_shape);
    out_cpp.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32,
                                     scale_inv_shape);
  }
  set_quantization_params(&out_cpp);
  return {std::move(out_cpp), std::move(tensor)};
}
void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
if (all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
......@@ -465,71 +789,204 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
}
TensorWrapper tensor(NVTE_MXFP8_1D_SCALING);
at::TensorOptions opts;
at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv,
columnwise_scale_inv; // TODO(pgadzinski) - change
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(torch_shape, opts);
// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
rowwise_scale_inv = at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
" (got shape=", shape, ")");
const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);
if (columnwise_usage) {
auto scale_shape = get_scale_shape(shape, true);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
columnwise_data = at::empty(torch_shape, opts);
columnwise_scale_inv =
at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(
columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
// Allocate tensors
at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor;
at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor;
const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
if (rowwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
rowwise_scale_inv_shape.end());
rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts);
}
this->set_quantization_params(&tensor);
py::object ret;
if (columnwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
columnwise_scale_inv_shape.end());
columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts);
}
// Convert tensors to Python
auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object {
return need_cast ? py::cast(tensor) : py::none();
};
auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage);
auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage);
auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage);
auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage);
// Construct Python MXFP8 tensor
py::object out_py;
if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorBasePythonClass));
ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
"rowwise_scale_inv"_a = rowwise_scale_inv,
"columnwise_scale_inv"_a = columnwise_scale_inv,
out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
} else {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
"rowwise_scale_inv"_a = rowwise_scale_inv,
"columnwise_scale_inv"_a = columnwise_scale_inv,
out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
}
return {std::move(tensor), std::move(ret)};
// Construct C++ MXFP8 tensor
TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
rowwise_scale_inv_shape);
}
if (columnwise_usage) {
out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
columnwise_scale_inv_shape);
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
// Adapts an existing Python MXFP8Tensor to this quantizer's usages and wraps
// its buffers in a C++ TensorWrapper.
//
// For each direction (row-wise, column-wise): if the usage is enabled, a
// missing data buffer is allocated uninitialized and a missing scale-inverse
// buffer is allocated zero-filled (both uint8 on CUDA); if the usage is
// disabled, any existing buffers are removed from the Python tensor. The
// tensor's `_fp8_dtype` attribute is overwritten. Returns the C++ wrapper
// together with the (mutated) Python tensor.
std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
    py::object tensor) const {
  NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor.");

  // Read a tensor-valued attribute, mapping Python None to std::nullopt
  const auto load_attr = [&tensor](const char* attr_name) -> std::optional<at::Tensor> {
    auto value_py = tensor.attr(attr_name);
    if (value_py.is_none()) {
      return std::nullopt;
    }
    return value_py.cast<at::Tensor>();
  };
  auto rowwise_data = load_attr("_rowwise_data");
  auto rowwise_scale_inv = load_attr("_rowwise_scale_inv");
  auto columnwise_data = load_attr("_columnwise_data");
  auto columnwise_scale_inv = load_attr("_columnwise_scale_inv");
  NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data.");

  // Tensor dimensions (both data buffers must agree when both exist)
  std::vector<size_t> shape;
  if (!columnwise_data) {
    // The check above guarantees that row-wise data is present in this branch
    shape = getTensorShape(*rowwise_data);
  } else {
    shape = getTensorShape(*columnwise_data);
    if (rowwise_data) {
      auto expected_shape = getTensorShape(*rowwise_data);
      NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape,
                 ") and column-wise data (shape=", shape, ") do not match");
    }
  }

  // Bring one (data, scale-inverse) buffer pair in line with its usage flag,
  // mirroring every change onto the matching Python attributes
  const auto coerce_buffers = [&](bool in_use, bool columnwise,
                                  std::optional<at::Tensor>& data_buf,
                                  std::optional<at::Tensor>& scale_inv_buf, const char* data_attr,
                                  const char* scale_inv_attr) {
    if (!in_use) {
      if (data_buf) {
        data_buf.reset();
        tensor.attr(data_attr) = py::none();
      }
      if (scale_inv_buf) {
        scale_inv_buf.reset();
        tensor.attr(scale_inv_attr) = py::none();
      }
      return;
    }
    if (!data_buf) {
      const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      data_buf = at::empty(shape_int64, opts);
      tensor.attr(data_attr) = *data_buf;
    }
    if (!scale_inv_buf) {
      const auto scale_inv_shape = get_scale_shape(shape, columnwise);
      const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
                                                       scale_inv_shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      scale_inv_buf = at::zeros(scale_inv_shape_int64, opts);
      tensor.attr(scale_inv_attr) = *scale_inv_buf;
    }
  };

  // Coerce row-wise buffers, then column-wise buffers
  coerce_buffers(rowwise_usage, false, rowwise_data, rowwise_scale_inv, "_rowwise_data",
                 "_rowwise_scale_inv");
  coerce_buffers(columnwise_usage, true, columnwise_data, columnwise_scale_inv,
                 "_columnwise_data", "_columnwise_scale_inv");

  // Remaining Python attributes
  tensor.attr("_fp8_dtype") = dtype;

  // C++ view over the buffers that are in use
  TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
  if (rowwise_usage) {
    out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape);
    out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
                                  getTensorShape(*rowwise_scale_inv));
  }
  if (columnwise_usage) {
    out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, shape);
    out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
                                     getTensorShape(*columnwise_scale_inv));
  }
  this->set_quantization_params(&out_cpp);

  return {std::move(out_cpp), std::move(tensor)};
}
void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
if (input.numel() == 0) {
return;
}
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape,
......
......@@ -22,8 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
......@@ -480,18 +479,11 @@ class BasicLinear(BasicOperation):
raise ValueError("Output tensor is quantized, but quantizer was not provided")
else:
output_quantizer = None
if isinstance(output_quantizer, MXFP8Quantizer):
raise RuntimeError(
"Attempting to generate MXFP8 output tensor, "
"but GEMM with MXFP8 output is not supported"
)
if isinstance(output_quantizer, Float8BlockQuantizer):
if output_quantizer is not None:
if not isinstance(output_quantizer, Float8Quantizer):
raise RuntimeError(
"Attempting to generate Float8BlockQuantized output tensor, "
"but GEMM with Float8BlockQuantized output is not supported"
"Attempting to generate quantized output tensor with unsupported quantizer"
)
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Check if accumulating into output tensor
......@@ -765,10 +757,11 @@ class BasicLinear(BasicOperation):
)
else:
grad_input_quantizer = None
if isinstance(grad_input_quantizer, MXFP8Quantizer):
if grad_input_quantizer is not None:
if not isinstance(grad_input_quantizer, Float8Quantizer):
raise RuntimeError(
"Attempting to generate MXFP8 grad input tensor, "
"but GEMM with MXFP8 output is not supported"
"Attempting to generate quantized grad input tensor "
"with unsupported quantizer"
)
# Check if accumulating into grad input tensor
......
......@@ -182,7 +182,7 @@ class UserbuffersForwardLinear(FusedOperation):
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if output_quantizer is not None:
raise ValueError("FP8 output is not supported")
raise ValueError("Quantized output is not supported")
else:
input_quantizer = None
weight_quantizer = None
......
......@@ -59,7 +59,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
......
......@@ -86,7 +86,7 @@ class Float8TensorBase(QuantizedTensorBase):
else:
instance = super().__new__(cls, *args, **kwargs)
instance._data = data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._scale_inv = fp8_scale_inv
instance._transpose = data_transpose
......
......@@ -83,7 +83,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
......
......@@ -521,7 +521,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
dst._rowwise_data = src._rowwise_data
dst._columnwise_data = src._columnwise_data
dst._quantizer = src._quantizer
dst._quantizer = src._quantizer.copy()
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
......
......@@ -108,10 +108,9 @@ class Float8Quantizer(Quantizer):
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
inner_dim = data.size(-1)
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
data_transpose = torch.empty(
inner_dim,
data.numel() // inner_dim,
transpose_shape,
dtype=torch.uint8,
device=device,
)
......@@ -230,7 +229,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_epsilon: float = 0.0,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.ones(1, dtype=torch.float32, device=device)
self.scale = torch.empty(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device)
self.dtype = fp8_dtype
self.with_amax_reduction = with_amax_reduction
......@@ -690,7 +689,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
# Float8Tensor attributes
self._data = tensor._data
self._quantizer = tensor._quantizer
self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype
self._scale_inv = tensor._scale_inv
self._transpose = tensor._transpose
......
......@@ -433,7 +433,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor)
self._rowwise_data = tensor._rowwise_data
self._columnwise_data = tensor._columnwise_data
self._quantizer = tensor._quantizer
self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype
self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment