Unverified Commit 7948779c authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] GroupedQuantizer and GroupedScaledTensor (#1666)



* refactor the multi_stream utils + implement nvte_multi_tensor_quantize in TE/Common

* implement GroupedQuantizer and grouped_quantize in jaxx

* fix logical_axes_names for transpose tensor in ScaledTensor
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarHua Huang <huah@nvidia.com>
Co-authored-by: default avatarMing Huang <mingh@nvidia.com>
parent 9985b02c
......@@ -30,6 +30,7 @@
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h>
......
......@@ -258,7 +258,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Get cublasLt version", py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
py::call_guard<py::gil_scoped_release>());
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor,
......
......@@ -83,7 +83,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace:
for _ in range(tex._num_cublas_streams):
for _ in range(tex.get_num_cublas_streams()):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment