Unverified Commit 9239bf71 authored by ElizaWszola's avatar ElizaWszola Committed by GitHub
Browse files

[Kernel] CUTLASS grouped gemm fp8 MoE kernel (#13972)


Signed-off-by: default avatarElizaWszola <eliza@neuralmagic.com>
Signed-off-by: default avatarElizaWszola <ewszola@redhat.com>
Co-authored-by: default avatarLucas Wilkinson <wilkinson.lucas@gmail.com>
parent 7a6d45bc
...@@ -50,6 +50,16 @@ def cutlass_block_fp8_supported() -> bool: ...@@ -50,6 +50,16 @@ def cutlass_block_fp8_supported() -> bool:
return ops.cutlass_scaled_mm_supports_block_fp8(capability) return ops.cutlass_scaled_mm_supports_block_fp8(capability)
def cutlass_group_gemm_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return ops.cutlass_group_gemm_supported(capability)
CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported() CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
......
...@@ -1568,18 +1568,21 @@ class ClassRegistry(UserDict[Type[T], _V]): ...@@ -1568,18 +1568,21 @@ class ClassRegistry(UserDict[Type[T], _V]):
return any(cls in self.data for cls in key.mro()) return any(cls in self.data for cls in key.mro())
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: def weak_ref_tensor(tensor: Any) -> Any:
""" """
Create a weak reference to a tensor. Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor, The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive. but will not keep the original tensor alive.
""" """
return torch.ops._C.weak_ref_tensor(tensor) if isinstance(tensor, torch.Tensor):
return torch.ops._C.weak_ref_tensor(tensor)
else:
return tensor
def weak_ref_tensors( def weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]: ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
""" """
Convenience function to create weak references to tensors, Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors. for single tensor, list of tensors or tuple of tensors.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment