Unverified commit 8ef8285c, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Minor improvements to CPU overhead (#2400)



* Minor CPU overhead changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cache per device
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 49f7c1db
......@@ -717,11 +717,16 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}
int nvte_is_non_tn_fp8_gemm_supported() {
int deviceComputeCapability =
transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
std::call_once(flags[device_id], [&]() {
int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
cache[device_id] = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
});
return cache[device_id];
}
......@@ -12,10 +12,7 @@ import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor
from ..quantized_tensor import Quantizer, QuantizedTensor, QuantizedTensorStorage
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..quantized_tensor import Quantizer
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_custom
from ..custom_recipes.gemm import custom_gemm
......@@ -46,8 +43,10 @@ def get_cublas_workspace(device: int, ub: bool, grouped_gemm: bool) -> torch.Ten
if ub:
return torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device
).repeat(_NUM_MAX_UB_STREAMS)
get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS,
dtype=torch.uint8,
device=device,
)
if grouped_gemm:
_multi_stream_cublas_workspace = []
for _ in range(tex.get_num_cublas_streams()):
......@@ -69,29 +68,25 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
def get_tensor_device(tensor: torch.Tensor) -> int:
"""Returns tensor device as an integer"""
if not isinstance(tensor, QuantizedTensorStorage):
return tensor.device.index
if isinstance(tensor, QuantizedTensor):
"""
Returns tensor device as an integer.
This method is used because checking instances of
QuantizedTensor or Storage incurs more CPU overhead.
The order of attributes checked is important to also
minimize overhead.
"""
if hasattr(tensor, "device"):
return tensor.device.index
if isinstance(tensor, (Float8BlockwiseQTensorStorage, MXFP8TensorStorage, NVFP4TensorStorage)):
return (
tensor._rowwise_data.device.index
if tensor._rowwise_data is not None
else tensor._columnwise_data.device.index
)
if isinstance(tensor, Float8TensorStorage):
return (
tensor._data.device.index
if tensor._data is not None
else tensor._transpose.device.index
)
try:
return (
tensor._data.device.index if tensor._data is not None else tensor._data_t.device.index
)
except AttributeError:
return torch.cuda.current_device()
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
return torch.cuda.current_device()
def general_gemm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment