Commit 2b05e121 authored by yuguo


Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/
+#include "../extensions.h"
 #include "common.h"
-#include "extensions.h"
 #include "pybind.h"
 namespace {
@@ -24,12 +24,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
   NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
-  size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype());
+  size_t element_size_bits = transformer_engine::pytorch::typeToNumBits(self.dtype());
   int32_t start_row = start_index.data_ptr<int32_t>()[0];
   void *base_ptr = static_cast<char *>(self.get_rowwise_data().data_ptr) +
-                   static_cast<size_t>(start_row) * fcd_size * element_size;
+                   static_cast<size_t>(start_row) * fcd_size * element_size_bits / 8;
   size_t num_rows_to_zero = max_tokens - start_row;
-  size_t total_bytes = num_rows_to_zero * fcd_size * element_size;
+  size_t total_bytes = num_rows_to_zero * fcd_size * element_size_bits / 8;
   NVTE_SCOPED_GIL_RELEASE(
       { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });
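Context for the change above: typeToSize reported whole bytes, which cannot describe sub-byte dtypes, so sizes are now computed in bits via typeToNumBits and divided by 8 at the end. A minimal sketch of the arithmetic (values hypothetical):

fcd_size = 1024        # elements per row
element_size_bits = 4  # e.g. a 4-bit dtype, not representable in whole bytes
row_bytes = fcd_size * element_size_bits // 8  # 512 bytes per row
assert row_bytes == 512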
@@ -57,17 +57,17 @@ namespace transformer_engine::pytorch {
 // get the fused attention backend
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
-    const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
 #ifdef __HIP_PLATFORM_AMD__
   return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 #else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim_qk, head_dim_v, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
+      bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
+      max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
 #endif
 }
@@ -6,8 +6,8 @@
 #include "transformer_engine/cast.h"
+#include "../extensions.h"
 #include "common.h"
-#include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
@@ -81,6 +81,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
@@ -216,6 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
+at::Stream CommOverlap::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
+}
 /***************************************************************************************************
  * CommOverlapP2P
  **************************************************************************************************/
@@ -300,3 +304,7 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
   const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
+at::Stream CommOverlapP2P::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device());
+}
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../extensions.h"
 namespace transformer_engine::pytorch {
@@ -10,10 +10,10 @@
 #include <string>
+#include "../common.h"
+#include "../extensions.h"
 #include "common.h"
 #include "common/util/cuda_runtime.h"
 #include "common/util/system.h"
-#include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
 #include "util.h"
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../extensions.h"
 namespace transformer_engine::pytorch {
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../../extensions.h"
 namespace transformer_engine::pytorch {
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../../extensions.h"
 namespace transformer_engine::pytorch {
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../../extensions.h"
 namespace transformer_engine::pytorch {
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../../extensions.h"
 namespace transformer_engine::pytorch {
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../../extensions.h"
 namespace transformer_engine::pytorch {
@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/
+#include "../extensions.h"
 #include "common/util/system.h"
-#include "extensions.h"
 #include "pybind.h"
 namespace transformer_engine::pytorch {
@@ -170,6 +170,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
@@ -328,6 +331,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../extensions.h"
 #include "pybind.h"
 namespace transformer_engine::pytorch {
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../extensions.h"
 namespace transformer_engine::pytorch {
@@ -261,7 +261,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
   m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
         py::call_guard<py::gil_scoped_release>());
-  m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
+  m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
+        py::call_guard<py::gil_scoped_release>());
 #ifdef USE_ROCM
   m.attr("_num_cublas_batchgemm_streams") = py::int_(transformer_engine::num_batchgemm_streams);
 #endif
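On the Python side, the binding change above replaces a module attribute frozen at import time with a function call into the library. A hedged sketch (transformer_engine_torch is the module bound in this file; the old attribute is shown commented out):

import transformer_engine_torch as tex

# old: n = tex._num_cublas_streams    # module attribute set once at import
n = tex.get_num_cublas_streams()      # new: queried from the library at call time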
@@ -390,7 +391,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
-           py::arg("shape") = std::nullopt);
+           py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlap::get_communication_stream);
   py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
              transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
@@ -407,5 +409,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
-           py::arg("shape") = std::nullopt);
+           py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
 }
@@ -9,8 +9,8 @@
 #include <string>
 #include "common/common.h"
-#include "extensions.h"
+#include "../extensions.h"
 #include "transformer_engine/transformer_engine.h"
 namespace transformer_engine::pytorch {
@@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
                                                  const std::string& amax_compute_algo,
                                                  DType fp8_dtype, float margin) {
   size_t num_tensors = amax_histories.size();
-  std::vector<Tensor> t_amax_histories(num_tensors);
-  std::vector<Tensor> t_scales(num_tensors);
-  std::vector<NVTETensor> te_amax_histories(num_tensors);
-  std::vector<NVTETensor> te_scales(num_tensors);
+  std::vector<NVTETensor> te_amax_histories;
+  std::vector<NVTETensor> te_scales;
+  te_amax_histories.reserve(num_tensors);
+  te_scales.reserve(num_tensors);
   for (size_t i = 0; i < num_tensors; i++) {
-    t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
-    auto amax_sizes = amax_histories[i].sizes().vec();
-    std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
-    t_amax_histories[i].data.shape = amax_shape;
-    t_amax_histories[i].data.dtype = DType::kFloat32;
-    t_scales[i].data.dptr = scales[i].data_ptr();
-    auto scale_sizes = scales[i].sizes().vec();
-    std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
-    t_scales[i].data.shape = scale_shape;
-    t_scales[i].data.dtype = DType::kFloat32;
-    te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
-    te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
+    te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
+    NVTETensor& amax_history = te_amax_histories.back();
+    NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes());
+    NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(),
+                                         static_cast<NVTEDType>(DType::kFloat32), amax_shape};
+    nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data);
+    te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
+    NVTETensor& scale = te_scales.back();
+    NVTEShape scale_shape = convertTorchShape(scales[i].sizes());
+    NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast<NVTEDType>(DType::kFloat32),
+                                  scale_shape};
+    nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data);
   }
   nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
       makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales,
       amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
       at::cuda::getCurrentCUDAStream());
+  for (auto& t : te_amax_histories) {
+    nvte_destroy_tensor(t);
+  }
+  for (auto& t : te_scales) {
+    nvte_destroy_tensor(t);
+  }
 }
 }  // namespace transformer_engine::pytorch
@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
-#include "extensions.h"
+#include "../extensions.h"
 namespace transformer_engine::pytorch {
@@ -6,7 +6,7 @@
 #include <optional>
-#include "extensions.h"
+#include "../extensions.h"
 #include "pybind.h"
 namespace transformer_engine::pytorch {
@@ -261,6 +261,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
   this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
   NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
              "Unsupported block scaling dim.");
+  this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
 }
 void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
@@ -299,6 +300,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   size_t m_dim = numel / k_dim;
   constexpr size_t kBlockLen = 128;
+  Float8BlockScaleTensorFormat data_format =
+      (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
+                        : Float8BlockScaleTensorFormat::GEMM_READY);
   if (rowwise_usage) {
     if (rowwise_data.has_value()) {
       data_rowwise = std::move(*rowwise_data);
@@ -308,14 +313,24 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      // 1D scaling can be GEMM_READY or COMPACT
+      bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
+      // the default rowwise scaling-factor shape is already transposed, i.e. GEMM_READY
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(m_dim, 4);
+      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
+      // in the compact rowwise format the scaling factor is not transposed
+      if (rowwise_compact) {
+        std::swap(sinv0, sinv1);
+      }
     } else {
-      NVTE_CHECK(false,
-                 "Unsupported block_scaling_dim in create_tensor rowwise."
+      NVTE_ERROR(
+          "Unsupported block_scaling_dim in create_tensor rowwise. "
           "Expected 1 or 2. Got ",
           block_scaling_dim);
     }
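Worked sketch of the rowwise scale-inv shapes computed above, assuming a 256x1024 input (m_dim=256, k_dim=1024) with 1D (1x128) scaling; the numbers follow directly from the code:

from math import ceil

def roundup(x, multiple):
    return (x + multiple - 1) // multiple * multiple

m_dim, k_dim, block_len = 256, 1024, 128
# GEMM_READY: scales stored transposed and padded for cuBLAS
gemm_ready = (ceil(k_dim / block_len), roundup(m_dim, 4))  # (8, 256)
# COMPACT: unpadded and untransposed, hence the std::swap above
compact = (m_dim, ceil(k_dim / block_len))                 # (256, 8)
assert gemm_ready == (8, 256) and compact == (256, 8)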
@@ -332,6 +347,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
                columnwise_shape, " torch shape: ", torch_columnwise_shape);
     if (torch_shape.size() > 0) {
+      if (!all_gather_usage) {
       torch_columnwise_shape.reserve(torch_shape.size());
       columnwise_shape.reserve(shape.size());
       torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
@@ -340,18 +356,32 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         torch_columnwise_shape.push_back(torch_shape[i]);
         columnwise_shape.push_back(shape[i]);
       }
+      } else {
+        // assert we are doing 1D scaling
+        NVTE_CHECK(block_scaling_dim == 1,
+                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
+        torch_columnwise_shape = torch_shape;
+        columnwise_shape = shape;
+      }
     }
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(k_dim, 4);
+      sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
+      // GEMM_READY case: the scaling factor is [sinv0, sinv1], already transposed for cuBLAS.
+      // COMPACT case: 128x1 scaling is applied without transposing the columnwise data, so the
+      // scaling factor is also [sinv0, sinv1]; no need to swap sinv0 and sinv1 here.
     } else {
-      NVTE_CHECK(false,
-                 "Unsupported block_scaling_dim in create_tensor columnwise."
+      NVTE_ERROR(
+          "Unsupported block_scaling_dim in create_tensor columnwise. "
           "Expected 1 or 2. Got ",
          block_scaling_dim);
     }
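Companion sketch for the columnwise branch, same 256x1024 input: with 128x1 scaling both formats yield a scale inv of ceil(256/128) x 1024 = 2x1024, already in the orientation cuBLAS expects, which is why no swap is needed here:

from math import ceil

def roundup(x, multiple):
    return (x + multiple - 1) // multiple * multiple

m_dim, k_dim, block_len = 256, 1024, 128
gemm_ready = (ceil(m_dim / block_len), roundup(k_dim, 4))  # (2, 1024)
compact = (ceil(m_dim / block_len), k_dim)                 # (2, 1024), same orientation
assert gemm_ready == compact == (2, 1024)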
@@ -373,7 +403,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
         "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
         "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
-        "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format);
   } else {
     py::handle Float8BlockwiseQTensorClass(
         reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
@@ -381,7 +411,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
         "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
         "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
-        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
+        "data_format"_a = data_format);
   }
   return {std::move(tensor), std::move(ret)};
@@ -8,6 +8,7 @@ from __future__ import annotations
 from collections.abc import Iterable
 from contextlib import contextmanager, AbstractContextManager, ContextDecorator
 from functools import lru_cache
+from dataclasses import dataclass
 import math
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
@@ -19,6 +20,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
+try:
+    import torch.distributed._symmetric_memory as symm_mem
+
+    HAS_TORCH_SYMMETRIC = True
+except ImportError:
+    HAS_TORCH_SYMMETRIC = False
+
 import transformer_engine_torch as tex
 from . import torch_version
 from .utils import (
     is_non_tn_fp8_gemm_supported,
@@ -34,14 +44,8 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
 from .tensor._internal.float8_tensor_base import Float8TensorBase
 from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
-from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
+from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
-try:
-    import torch.distributed._symmetric_memory as symm_mem
-
-    HAS_TORCH_SYMMETRIC = True
-except ImportError:
-    HAS_TORCH_SYMMETRIC = False
 __all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -943,7 +947,7 @@ def _all_gather_fp8(
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
     elif isinstance(inp, Float8Tensor):
         out = inp.make_like(inp, shape=out_shape)
-        out._data = torch.empty_like(
+        out._data = torch.empty(
             out_shape,
             dtype=torch.uint8,
             device=inp.device,
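Why the fix above works: torch.empty_like expects a tensor prototype, so passing a shape list raises a TypeError, whereas torch.empty accepts the target shape directly. A minimal sketch:

import torch

out_shape = [4, 8]
try:
    torch.empty_like(out_shape, dtype=torch.uint8)  # old call: TypeError (list is not a Tensor)
except TypeError:
    pass
buf = torch.empty(out_shape, dtype=torch.uint8)     # fixed call: allocates a 4x8 buffer
assert tuple(buf.shape) == (4, 8)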
@@ -977,6 +981,67 @@ def _all_gather_fp8(
     return out, handle
+
+
+def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
+    """Set whether the quantizer produces the compact all-gather format"""
+    _quantizer = quantizer
+    if isinstance(quantizer, DebugQuantizer):
+        _quantizer = quantizer.parent_quantizer
+    if isinstance(_quantizer, Float8BlockQuantizer):
+        _quantizer.all_gather_usage = compact
+
+
+def _post_process_fp8_blockwise_gather(
+    out: Float8BlockwiseQTensorBase,
+    quantizer: Float8BlockQuantizer,
+    handle: Optional[torch.distributed.Work] = None,
+) -> Float8BlockwiseQTensorBase:
+    """Post-process FP8 blockwise gather."""
+    if handle is not None:
+        handle.wait()
+        handle = None
+    if out._is_gemm_ready_format():
+        return out
+    needs_columnwise_data_transpose = (
+        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
+    )
+    need_rowwise_scale_transpose = (
+        quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
+    )
+    # cuBLAS requires a transpose of the scale-inv tensor. Suppose the original input is 256x1024:
+    # the columnwise compact format does 128x1 quantization of it, so the quantized tensor is
+    # 256x1024 and the scale inv is 2x1024. In the GEMM_READY format this is equivalent to 1x128
+    # quantization of the transposed 1024x256 tensor, whose scale inv is 1024x2; cuBLAS requires
+    # 2x1024. Therefore we do not need to transpose the scale inv, only the columnwise data.
+    if needs_columnwise_data_transpose:
+        out._transpose_columnwise_data()
+    if need_rowwise_scale_transpose:
+        out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
+    out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
+    return out
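A shape walk-through of the post-processing above, using the 256x1024 example from the comment (plain torch; the real code operates on the quantized tensor's fields):

import torch

data = torch.zeros(256, 1024, dtype=torch.uint8)      # columnwise data, COMPACT
scale_inv = torch.zeros(2, 1024)                       # ceil(256/128) x 1024
data_gemm_ready = data.transpose(-2, -1).contiguous()  # 1024x256, as cuBLAS wants
# scale_inv already matches the GEMM_READY layout, so it is left untouched
assert data_gemm_ready.shape == (1024, 256)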
+
+@dataclass
+class _FP8BlockwiseAllGatherAsyncHandle:
+    """Handle for asynchronous FP8 blockwise all-gather."""
+
+    tensor: Float8BlockwiseQTensorBase
+    quantizer: Float8BlockQuantizer
+    async_handle: torch.distributed.Work
+    _synchronized: bool = False
+
+    def wait(self) -> None:
+        """Wait for the async operation to complete and post-process the tensor."""
+        if self._synchronized:
+            return
+        self.async_handle.wait()
+        _post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
+        self._synchronized = True
+
+
 def _all_gather_fp8_blockwise(
     inp: torch.Tensor,
     process_group: dist_group_type,
@@ -990,8 +1055,9 @@ def _all_gather_fp8_blockwise(
     Returns: quantizer(gather(inp))
-    NOTE: The implementation is not sophisticated enough to honor async_op=True.
-    In some cases it falls back to synchronous gather and invokes the quantizer.
+    NOTE: The implementation only honors async_op=True for the FP8 gather case.
+    If the tensor shape is not divisible by 128, it falls back to synchronous
+    gather and invokes the quantizer.
     """
     # Input tensor attributes
@@ -1027,7 +1093,11 @@ def _all_gather_fp8_blockwise(
     out_shape[0] *= world_size
     # Doing BF16 gather for now as baseline because it's simpler
-    if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
+    if (
+        not isinstance(inp, Float8BlockwiseQTensorBase)
+        and quantizer is not None
+        and not quantizer.is_quantizable(inp)
+    ):
         out = torch.empty(
             out_shape,
             dtype=dtype,
@@ -1035,14 +1105,93 @@ def _all_gather_fp8_blockwise(
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
+        orig_all_gather_usage = quantizer.all_gather_usage
+        quantizer.all_gather_usage = False
         out = quantizer(out)
+        quantizer.all_gather_usage = orig_all_gather_usage
         return out, None
-    # Implementation of fp8 gather needs to account for:
-    # * Getting columnwise data as a transpose of how it is stored for GEMMS.
-    # * Gathering non GEMM swizzled scales.
-    # * Refer to scaffold code when implementing at:
-    #   https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
-    raise NotImplementedError("fp8 blockwise allgather not yet implemented")
+    # Cast the input tensor to a Float8BlockwiseQTensor with the required data.
+    # Force compact usage in case the quantizer is not correctly configured.
+    orig_all_gather_usage = quantizer.all_gather_usage
+    quantizer.all_gather_usage = True
+    if not isinstance(inp, Float8BlockwiseQTensorBase):
+        inp = quantizer(inp)
+    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
+        quantizer.columnwise_usage and inp._columnwise_data is None
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to Float8BlockwiseQTensor."
+        )
+        inp = quantizer(inp.dequantize())
+    quantizer.all_gather_usage = orig_all_gather_usage
+
+    # About to start network communication, so make sure the data format is compact
+    if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
+        raise RuntimeError(
+            "All-gather with FP8 block-wise quantized tensor requires compact data format, "
+            f"but found data_format={inp._data_format}"
+        )
+
+    # Construct Float8BlockwiseQTensor output tensor
+    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+
+    # Coalesce NCCL collectives
+    with torch.distributed._coalescing_manager(
+        group=process_group,
+        device=device,
+        async_ops=async_op,
+    ) as coalescing_manager:
+        # Gather Float8BlockwiseQTensor data for row-wise usage
+        if quantizer.rowwise_usage:
+            # Launch all-gathers
+            torch.distributed.all_gather_into_tensor(
+                out._rowwise_scale_inv,
+                inp._rowwise_scale_inv,
+                group=process_group,
+            )
+            torch.distributed.all_gather_into_tensor(
+                out._rowwise_data,
+                inp._rowwise_data,
+                group=process_group,
+            )
+        # Gather Float8BlockwiseQTensor data for column-wise usage
+        if quantizer.columnwise_usage:
+            # Launch all-gathers
+            torch.distributed.all_gather_into_tensor(
+                out._columnwise_scale_inv,
+                inp._columnwise_scale_inv,
+                group=process_group,
+            )
+            torch.distributed.all_gather_into_tensor(
+                out._columnwise_data,
+                inp._columnwise_data,
+                group=process_group,
+            )
+    handle = coalescing_manager if async_op else None
+
+    # Unlike MXFP8, this FP8 blockwise tensor primarily targets Hopper, so the gathered
+    # columnwise data must be transposed. An example use is the grad_output tensor,
+    # i.e. dY in linear backward: we gather two FP8 tensors (rowwise and columnwise)
+    # along dim0 and then transpose the columnwise data to match the rowwise data.
+    # Make sure the FP8 transpose is populated if needed.
+    if async_op:
+        handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
+    else:
+        # For a sync op, do the transpose here as a post-processing step
+        _post_process_fp8_blockwise_gather(out, quantizer, handle)
+    return out, handle
 def _all_gather_mxfp8(
@@ -1239,12 +1388,18 @@ def gather_along_first_dim(
         final_quantizer = (
             None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
         )
+        # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+        if isinstance(rowwise, Float8BlockwiseQTensorBase):
+            rowwise = inp._original_tensor
         rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
         out_obj.rowwise_gemm_tensor = rowwise_total
         if rowwise is not columnwise:
             final_quantizer_columnwise = (
                 None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
             )
+            # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+            if isinstance(columnwise, Float8BlockwiseQTensorBase):
+                columnwise = inp._original_tensor
             columnwise_total, _ = gather_along_first_dim(
                 columnwise, process_group, False, final_quantizer_columnwise
             )
@@ -1261,6 +1416,9 @@ def gather_along_first_dim(
         )
     if isinstance(inp, QuantizedTensor):
         inp = inp.dequantize()
+    # Falling back to a high-precision all-gather for Float8BlockQuantizer
+    # means it should directly output the GEMM_READY format
+    _set_quantizer_format(quantizer, compact=False)
     out = torch.empty(
         out_shape,
         dtype=inp.dtype,