Unverified Commit 8ef8285c authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Minor improvements to CPU overhead (#2400)



* Minor CPU overhead changes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cache per device
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 49f7c1db
...@@ -717,11 +717,16 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { ...@@ -717,11 +717,16 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
} }
int nvte_is_non_tn_fp8_gemm_supported() { int nvte_is_non_tn_fp8_gemm_supported() {
int deviceComputeCapability = int num_devices = transformer_engine::cuda::num_devices();
transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()); static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
// Note: this is temporary restriction and should be lifted in the future. int device_id = transformer_engine::cuda::current_device();
// (remove the note once it's done.) std::call_once(flags[device_id], [&]() {
return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) || int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
deviceComputeCapability >= 130; // Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
cache[device_id] = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
});
return cache[device_id];
} }
...@@ -12,10 +12,7 @@ import transformer_engine_torch as tex ...@@ -12,10 +12,7 @@ import transformer_engine_torch as tex
from ..constants import TE_DType from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor from ..utils import get_sm_count, _empty_tensor
from ..quantized_tensor import Quantizer, QuantizedTensor, QuantizedTensorStorage from ..quantized_tensor import Quantizer
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_custom from ..tensor.utils import is_custom
from ..custom_recipes.gemm import custom_gemm from ..custom_recipes.gemm import custom_gemm
...@@ -46,8 +43,10 @@ def get_cublas_workspace(device: int, ub: bool, grouped_gemm: bool) -> torch.Ten ...@@ -46,8 +43,10 @@ def get_cublas_workspace(device: int, ub: bool, grouped_gemm: bool) -> torch.Ten
if ub: if ub:
return torch.empty( return torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS,
).repeat(_NUM_MAX_UB_STREAMS) dtype=torch.uint8,
device=device,
)
if grouped_gemm: if grouped_gemm:
_multi_stream_cublas_workspace = [] _multi_stream_cublas_workspace = []
for _ in range(tex.get_num_cublas_streams()): for _ in range(tex.get_num_cublas_streams()):
...@@ -69,29 +68,25 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: ...@@ -69,29 +68,25 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
def get_tensor_device(tensor: torch.Tensor) -> int:
    """Return the device of *tensor* as an integer index.

    Attribute probing (``hasattr``/``getattr``) is used here instead of
    ``isinstance`` checks against QuantizedTensor/Storage classes because
    the instance checks incur more CPU overhead. The probe order is also
    chosen to minimize that overhead.
    """
    # Plain torch tensors (and anything tensor-like) expose ``.device``.
    if hasattr(tensor, "device"):
        return tensor.device.index
    # Storage objects: report the device of whichever payload is populated,
    # checked in decreasing order of likelihood.
    for attr in ("_rowwise_data", "_columnwise_data", "_data", "_transpose"):
        payload = getattr(tensor, attr, None)
        if payload is not None:
            return payload.device.index
    # No inspectable payload; assume the current CUDA device.
    return torch.cuda.current_device()
def general_gemm( def general_gemm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment