Commit 2b05e121 authored by yuguo


Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include "../extensions.h"
 #include "common.h"
-#include "extensions.h"
 #include "pybind.h"

 namespace {
@@ -24,12 +24,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
   NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
-  size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype());
+  size_t element_size_bits = transformer_engine::pytorch::typeToNumBits(self.dtype());
   int32_t start_row = start_index.data_ptr<int32_t>()[0];
   void *base_ptr = static_cast<char *>(self.get_rowwise_data().data_ptr) +
-                   static_cast<size_t>(start_row) * fcd_size * element_size;
+                   static_cast<size_t>(start_row) * fcd_size * element_size_bits / 8;
   size_t num_rows_to_zero = max_tokens - start_row;
-  size_t total_bytes = num_rows_to_zero * fcd_size * element_size;
+  size_t total_bytes = num_rows_to_zero * fcd_size * element_size_bits / 8;
   NVTE_SCOPED_GIL_RELEASE(
       { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });
@@ -57,17 +57,17 @@ namespace transformer_engine::pytorch {
 // get the fused attention backend
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
-    const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
 #ifdef __HIP_PLATFORM_AMD__
   return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 #else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim_qk, head_dim_v, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
+      bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
+      max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
 #endif
 }
...
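Note on the mha_fill change above: switching from typeToSize (bytes) to typeToNumBits lets the same arithmetic handle sub-byte element types, with the division by 8 deferred until a whole row of elements has been counted. A minimal sketch of that byte accounting (fill_region_bytes is a hypothetical name, not part of the codebase):

    def fill_region_bytes(start_row: int, max_tokens: int, fcd_size: int,
                          element_size_bits: int) -> tuple:
        """Return (byte offset of start_row, bytes to zero), mirroring mha_fill."""
        row_bits = fcd_size * element_size_bits
        return (start_row * row_bits // 8, (max_tokens - start_row) * row_bits // 8)

    # FP16 (16 bits): zeroing rows 2..7 of an 8x256 buffer starts at byte 1024
    assert fill_region_bytes(2, 8, 256, 16) == (1024, 3072)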
@@ -6,8 +6,8 @@

 #include "transformer_engine/cast.h"

+#include "../extensions.h"
 #include "common.h"
-#include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"

@@ -81,6 +81,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
...
@@ -216,6 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }

+at::Stream CommOverlap::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
+}
+
 /***************************************************************************************************
  * CommOverlapP2P
  **************************************************************************************************/
@@ -300,3 +304,7 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
   const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
+
+at::Stream CommOverlapP2P::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device());
+}
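The two get_communication_stream methods above expose internally owned CUDA streams to PyTorch via at::cuda::getStreamFromExternal. A Python-side sketch of the same wrapping idea, assuming you hold a raw cudaStream_t handle as an integer (names hypothetical):

    import torch

    def wrap_external_stream(raw_stream_ptr: int) -> torch.cuda.ExternalStream:
        # Wrap an externally owned CUDA stream so PyTorch can order work
        # against it; ownership stays with the creator, as in the C++ code.
        return torch.cuda.ExternalStream(raw_stream_ptr)

    # Typical use: make the compute stream wait for communication to drain.
    # torch.cuda.current_stream().wait_stream(wrap_external_stream(ptr))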
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -10,10 +10,10 @@
 #include <string>

 #include "../common.h"
+#include "../extensions.h"
 #include "common.h"
 #include "common/util/cuda_runtime.h"
 #include "common/util/system.h"
-#include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
 #include "util.h"
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include "../extensions.h"
 #include "common/util/system.h"
-#include "extensions.h"
 #include "pybind.h"

 namespace transformer_engine::pytorch {
@@ -170,6 +170,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
@@ -328,6 +331,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"
 #include "pybind.h"

 namespace transformer_engine::pytorch {
...
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -261,7 +261,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
   m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
         py::call_guard<py::gil_scoped_release>());
-  m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
+  m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
+        py::call_guard<py::gil_scoped_release>());
 #ifdef USE_ROCM
   m.attr("_num_cublas_batchgemm_streams") = py::int_(transformer_engine::num_batchgemm_streams);
 #endif
@@ -390,7 +391,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
           py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
-          py::arg("shape") = std::nullopt);
+          py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlap::get_communication_stream);
   py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
              transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
@@ -407,5 +409,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
           py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
-          py::arg("shape") = std::nullopt);
+          py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
 }
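With the binding change above, the module constant _num_cublas_streams becomes a function. A call-site migration could look like this (assuming the extension is imported as tex, as it is elsewhere in this merge):

    import transformer_engine_torch as tex

    # was: num_streams = tex._num_cublas_streams
    num_streams = tex.get_num_cublas_streams()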
@@ -9,8 +9,8 @@

 #include <string>

-#include "common/common.h"
-#include "extensions.h"
+#include "../extensions.h"
+#include "transformer_engine/transformer_engine.h"

 namespace transformer_engine::pytorch {
@@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
                                                  const std::string& amax_compute_algo,
                                                  DType fp8_dtype, float margin) {
   size_t num_tensors = amax_histories.size();
-  std::vector<Tensor> t_amax_histories(num_tensors);
-  std::vector<Tensor> t_scales(num_tensors);
-  std::vector<NVTETensor> te_amax_histories(num_tensors);
-  std::vector<NVTETensor> te_scales(num_tensors);
+  std::vector<NVTETensor> te_amax_histories;
+  std::vector<NVTETensor> te_scales;
+  te_amax_histories.reserve(num_tensors);
+  te_scales.reserve(num_tensors);
   for (size_t i = 0; i < num_tensors; i++) {
-    t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
-    auto amax_sizes = amax_histories[i].sizes().vec();
-    std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
-    t_amax_histories[i].data.shape = amax_shape;
-    t_amax_histories[i].data.dtype = DType::kFloat32;
-    t_scales[i].data.dptr = scales[i].data_ptr();
-    auto scale_sizes = scales[i].sizes().vec();
-    std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
-    t_scales[i].data.shape = scale_shape;
-    t_scales[i].data.dtype = DType::kFloat32;
-    te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
-    te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
+    te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
+    NVTETensor& amax_history = te_amax_histories.back();
+    NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes());
+    NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(),
+                                         static_cast<NVTEDType>(DType::kFloat32), amax_shape};
+    nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data);
+
+    te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
+    NVTETensor& scale = te_scales.back();
+    NVTEShape scale_shape = convertTorchShape(scales[i].sizes());
+    NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast<NVTEDType>(DType::kFloat32),
+                                  scale_shape};
+    nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data);
   }
   nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
       makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales,
       amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
       at::cuda::getCurrentCUDAStream());
+  for (auto& t : te_amax_histories) {
+    nvte_destroy_tensor(t);
+  }
+  for (auto& t : te_scales) {
+    nvte_destroy_tensor(t);
+  }
 }

 }  // namespace transformer_engine::pytorch
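The rewrite above replaces reinterpret_cast-ed stack-allocated Tensor structs with opaque handles from nvte_create_tensor, which must now be released with nvte_destroy_tensor. A language-neutral sketch of that create/use/destroy discipline (create, use, and destroy are hypothetical placeholders; the try/finally is stricter than the C++ above, which destroys only on the success path):

    def with_handles(n, create, use, destroy):
        """Acquire n opaque handles, hand them to `use`, and always release them."""
        handles = [create() for _ in range(n)]
        try:
            use(handles)
        finally:
            for h in handles:
                destroy(h)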
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
...
@@ -6,7 +6,7 @@

 #include <optional>

-#include "extensions.h"
+#include "../extensions.h"
 #include "pybind.h"

 namespace transformer_engine::pytorch {
...
@@ -261,6 +261,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
   this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
   NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
              "Unsupported block scaling dim.");
+  this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
 }

 void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
@@ -299,6 +300,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   size_t m_dim = numel / k_dim;
   constexpr size_t kBlockLen = 128;

+  Float8BlockScaleTensorFormat data_format =
+      (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
+                        : Float8BlockScaleTensorFormat::GEMM_READY);
+
   if (rowwise_usage) {
     if (rowwise_data.has_value()) {
       data_rowwise = std::move(*rowwise_data);
@@ -308,16 +313,26 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      // 1D scaling can be GEMM_READY or COMPACT
+      bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
+      // The default rowwise scaling-factor shape is already transposed, so it is GEMM_READY.
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(m_dim, 4);
+      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
+      // If the rowwise format is compact, the scaling factor is not transposed.
+      if (rowwise_compact) {
+        std::swap(sinv0, sinv1);
+      }
     } else {
-      NVTE_CHECK(false,
-                 "Unsupported block_scaling_dim in create_tensor rowwise."
+      NVTE_ERROR(
+          "Unsupported block_scaling_dim in create_tensor rowwise. "
           "Expected 1 or 2. Got ",
           block_scaling_dim);
     }
     scale_inv_rowwise =
         at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
@@ -332,28 +347,43 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
                columnwise_shape, " torch shape: ", torch_columnwise_shape);
     if (torch_shape.size() > 0) {
-      torch_columnwise_shape.reserve(torch_shape.size());
-      columnwise_shape.reserve(shape.size());
-      torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
-      columnwise_shape.push_back(shape[shape.size() - 1]);
-      for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
-        torch_columnwise_shape.push_back(torch_shape[i]);
-        columnwise_shape.push_back(shape[i]);
-      }
+      if (!all_gather_usage) {
+        torch_columnwise_shape.reserve(torch_shape.size());
+        columnwise_shape.reserve(shape.size());
+        torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
+        columnwise_shape.push_back(shape[shape.size() - 1]);
+        for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
+          torch_columnwise_shape.push_back(torch_shape[i]);
+          columnwise_shape.push_back(shape[i]);
+        }
+      } else {
+        // assert we are doing 1D scaling
+        NVTE_CHECK(block_scaling_dim == 1,
+                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
+        torch_columnwise_shape = torch_shape;
+        columnwise_shape = shape;
+      }
     }
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(k_dim, 4);
+      sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
+      // GEMM_READY case: the scaling factor is [sinv0, sinv1], already transposed here for cuBLAS.
+      // COMPACT case: since we apply 128x1 scaling here without transposing the columnwise data,
+      // the scaling factor is also [sinv0, sinv1], so there is no need to swap sinv0 and sinv1.
     } else {
-      NVTE_CHECK(false,
-                 "Unsupported block_scaling_dim in create_tensor columnwise."
+      NVTE_ERROR(
+          "Unsupported block_scaling_dim in create_tensor columnwise. "
          "Expected 1 or 2. Got ",
          block_scaling_dim);
    }
     data_colwise = at::empty(torch_columnwise_shape, opts);
     scale_inv_colwise =
@@ -373,7 +403,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
         "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
         "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
-        "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format);
   } else {
     py::handle Float8BlockwiseQTensorClass(
         reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
@@ -381,7 +411,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
         "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
         "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
-        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
+        "data_format"_a = data_format);
   }

   return {std::move(tensor), std::move(ret)};
...
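The COMPACT format introduced above only changes the 1D (1x128) scale-inverse layout; 2D scaling stays GEMM_READY. A worked sketch of the rowwise sinv0/sinv1 logic for a 256x1024 input with kBlockLen = 128 (function name hypothetical):

    def rowwise_scale_inv_shape(m_dim: int, k_dim: int, compact: bool,
                                block_len: int = 128) -> tuple:
        """Mirror the sinv0/sinv1 choice in Float8BlockQuantizer::create_tensor."""
        sinv0 = (k_dim + block_len - 1) // block_len
        sinv1 = m_dim if compact else ((m_dim + 3) // 4) * 4  # roundup(m_dim, 4)
        if compact:  # compact scales are left untransposed
            sinv0, sinv1 = sinv1, sinv0
        return (sinv0, sinv1)

    assert rowwise_scale_inv_shape(256, 1024, compact=False) == (8, 256)  # GEMM_READY
    assert rowwise_scale_inv_shape(256, 1024, compact=True) == (256, 8)   # COMPACT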
@@ -8,6 +8,7 @@ from __future__ import annotations
 from collections.abc import Iterable
 from contextlib import contextmanager, AbstractContextManager, ContextDecorator
 from functools import lru_cache
+from dataclasses import dataclass
 import math
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
@@ -19,6 +20,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

+try:
+    import torch.distributed._symmetric_memory as symm_mem
+
+    HAS_TORCH_SYMMETRIC = True
+except ImportError:
+    HAS_TORCH_SYMMETRIC = False
+
+import transformer_engine_torch as tex
+
 from . import torch_version
 from .utils import (
     is_non_tn_fp8_gemm_supported,
@@ -34,14 +44,8 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
 from .tensor._internal.float8_tensor_base import Float8TensorBase
 from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
-from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
-
-try:
-    import torch.distributed._symmetric_memory as symm_mem
-
-    HAS_TORCH_SYMMETRIC = True
-except ImportError:
-    HAS_TORCH_SYMMETRIC = False
+from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer

 __all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -943,7 +947,7 @@ def _all_gather_fp8(
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
     elif isinstance(inp, Float8Tensor):
         out = inp.make_like(inp, shape=out_shape)
-        out._data = torch.empty_like(
+        out._data = torch.empty(
             out_shape,
             dtype=torch.uint8,
             device=inp.device,
@@ -977,6 +981,67 @@ def _all_gather_fp8(
     return out, handle


+def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
+    """Make the quantizer output the compact (or GEMM-ready) format."""
+    _quantizer = quantizer
+    if isinstance(quantizer, DebugQuantizer):
+        _quantizer = quantizer.parent_quantizer
+    if isinstance(_quantizer, Float8BlockQuantizer):
+        _quantizer.all_gather_usage = compact
+
+
+def _post_process_fp8_blockwise_gather(
+    out: Float8BlockwiseQTensorBase,
+    quantizer: Float8BlockQuantizer,
+    handle: Optional[torch.distributed.Work] = None,
+) -> Float8BlockwiseQTensorBase:
+    """Post-process FP8 blockwise gather."""
+    if handle is not None:
+        handle.wait()
+        handle = None
+
+    if out._is_gemm_ready_format():
+        return out
+
+    needs_columnwise_data_transpose = (
+        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
+    )
+    need_rowwise_scale_transpose = (
+        quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
+    )
+    # cuBLAS requires a transpose of the scale-inv tensor. Suppose the original input is 256x1024:
+    # the columnwise compact format means doing 128x1 quantization of it,
+    # so the quantized tensor is 256x1024 and the scale inv is 2x1024.
+    # In the GEMM_READY format this is equivalent to 1x128 quantization
+    # of the transposed 1024x256 tensor, so the scale inv is 1024x2, and cuBLAS requires 2x1024.
+    # Therefore, it turns out we don't need to transpose the scale inv, only the columnwise data.
+    if needs_columnwise_data_transpose:
+        out._transpose_columnwise_data()
+    if need_rowwise_scale_transpose:
+        out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
+    out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
+    return out
+
+
+@dataclass
+class _FP8BlockwiseAllGatherAsyncHandle:
+    """Handle for asynchronous FP8 blockwise all-gather."""
+
+    tensor: Float8BlockwiseQTensorBase
+    quantizer: Float8BlockQuantizer
+    async_handle: torch.distributed.Work
+    _synchronized: bool = False
+
+    def wait(self) -> None:
+        """Wait for the async operation to complete and post-process the tensor."""
+        if self._synchronized:
+            return
+        self.async_handle.wait()
+        _post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
+        self._synchronized = True
+
+
 def _all_gather_fp8_blockwise(
     inp: torch.Tensor,
     process_group: dist_group_type,
@@ -990,8 +1055,9 @@ def _all_gather_fp8_blockwise(

     Returns: quantizer(gather(inp))

-    NOTE: The implementation is not sophisticated enough to honor async_op=True.
-    In some cases it falls back to synchronous gather and invokes the quantizer.
+    NOTE: The implementation only honors async_op=True for the FP8 gather case.
+    When the tensor shape is not divisible by 128, it falls back to a
+    synchronous gather and invokes the quantizer.
     """

     # Input tensor attributes
@@ -1027,7 +1093,11 @@
     out_shape[0] *= world_size

     # Doing BF16 gather for now as baseline because it's simpler
-    if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
+    if (
+        not isinstance(inp, Float8BlockwiseQTensorBase)
+        and quantizer is not None
+        and not quantizer.is_quantizable(inp)
+    ):
         out = torch.empty(
             out_shape,
             dtype=dtype,
@@ -1035,14 +1105,93 @@
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
+        orig_all_gather_usage = quantizer.all_gather_usage
+        quantizer.all_gather_usage = False
         out = quantizer(out)
+        quantizer.all_gather_usage = orig_all_gather_usage
         return out, None

     # Implementation of fp8 gather needs to account for:
     # * Getting columnwise data as a transpose of how it is stored for GEMMS.
     # * Gathering non GEMM swizzled scales.
-    # * Refer to scaffold code when implementing at:
-    # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
-    raise NotImplementedError("fp8 blockwise allgather not yet implemented")
+
+    # Cast the input tensor to a Float8BlockwiseQTensor with the required data.
+    # Set compact usage in case the quantizer is not correctly configured.
+    orig_all_gather_usage = quantizer.all_gather_usage
+    quantizer.all_gather_usage = True
+    if not isinstance(inp, Float8BlockwiseQTensorBase):
+        inp = quantizer(inp)
+    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
+        quantizer.columnwise_usage and inp._columnwise_data is None
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to Float8BlockwiseQTensor."
+        )
+        inp = quantizer(inp.dequantize())
+    quantizer.all_gather_usage = orig_all_gather_usage
+
+    # Before starting network communication, make sure the data is in the compact format
+    if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
+        raise RuntimeError(
+            "All-gather with FP8 block-wise quantized tensor requires compact data format, "
+            f"but found data_format={inp._data_format}"
+        )
+
+    # Construct the Float8BlockwiseQTensor output tensor
+    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+
+    # Coalesce NCCL collectives
+    with torch.distributed._coalescing_manager(
+        group=process_group,
+        device=device,
+        async_ops=async_op,
+    ) as coalescing_manager:
+
+        # Gather Float8BlockwiseQTensor data for row-wise usage
+        if quantizer.rowwise_usage:
+            # Launch all-gathers
+            torch.distributed.all_gather_into_tensor(
+                out._rowwise_scale_inv,
+                inp._rowwise_scale_inv,
+                group=process_group,
+            )
+            torch.distributed.all_gather_into_tensor(
+                out._rowwise_data,
+                inp._rowwise_data,
+                group=process_group,
+            )
+
+        # Gather Float8BlockwiseQTensor data for column-wise usage
+        if quantizer.columnwise_usage:
+            # Launch all-gathers
+            torch.distributed.all_gather_into_tensor(
+                out._columnwise_scale_inv,
+                inp._columnwise_scale_inv,
+                group=process_group,
+            )
+            torch.distributed.all_gather_into_tensor(
+                out._columnwise_data,
+                inp._columnwise_data,
+                group=process_group,
+            )
+
+    handle = coalescing_manager if async_op else None
+
+    # Unlike MXFP8, this FP8 blockwise tensor primarily works with Hopper,
+    # which means we need to transpose the gathered columnwise data.
+    # An example use is the grad_output tensor, i.e. dY in linear backward:
+    # we gather two FP8 tensors (rowwise and columnwise) along dim0
+    # and then transpose the columnwise data to match the rowwise data.
+    # Make sure the FP8 transpose is populated if needed.
+    if async_op:
+        handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
+    else:
+        # For a sync op, do the transpose here as a post-processing step
+        _post_process_fp8_blockwise_gather(out, quantizer, handle)
+
+    return out, handle


 def _all_gather_mxfp8(
@@ -1239,12 +1388,18 @@ def gather_along_first_dim(
         final_quantizer = (
             None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
         )
+        # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+        if isinstance(rowwise, Float8BlockwiseQTensorBase):
+            rowwise = inp._original_tensor
         rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
         out_obj.rowwise_gemm_tensor = rowwise_total
         if rowwise is not columnwise:
             final_quantizer_columnwise = (
                 None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
             )
+            # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+            if isinstance(columnwise, Float8BlockwiseQTensorBase):
+                columnwise = inp._original_tensor
             columnwise_total, _ = gather_along_first_dim(
                 columnwise, process_group, False, final_quantizer_columnwise
             )
@@ -1261,6 +1416,9 @@ def gather_along_first_dim(
         )
         if isinstance(inp, QuantizedTensor):
             inp = inp.dequantize()
+        # Falling back to a high-precision all-gather for Float8BlockQuantizer
+        # means that it should directly output the GEMM_READY format
+        _set_quantizer_format(quantizer, compact=False)
         out = torch.empty(
             out_shape,
             dtype=inp.dtype,
...
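A hypothetical call site for the new async path in _all_gather_fp8_blockwise: the returned handle defers both the NCCL wait and the COMPACT-to-GEMM_READY fix-up until wait() is called (keyword-argument names beyond inp and process_group are assumptions based on the docstring):

    def gather_blockwise_async(inp, process_group, quantizer):
        # Launch the coalesced FP8 all-gathers without blocking.
        out, handle = _all_gather_fp8_blockwise(
            inp, process_group, async_op=True, quantizer=quantizer
        )
        # ... overlap independent compute here ...
        if handle is not None:
            handle.wait()  # waits on the collectives, then fixes the layout
        return out  # now in GEMM_READY format, safe for cuBLAS GEMMs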