Unverified Commit 05dc1e62 authored by Kevin Tong, committed by GitHub

NVFP4 Move RHT BLAS to GPU (#2275)



* CUDA RHT
Signed-off-by: Kevin Tong <kevin@augmentcode.com>

* Fix cuda graphs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix bug where RHT mask is tensor instead of int
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Kevin Tong <kevin@augmentcode.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 9dd61922
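
In short, the change builds the RHT constants directly on the GPU instead of materializing them on the CPU and copying them over afterwards. A minimal before/after sketch of the pattern (assuming a CUDA-capable PyTorch build; the dimension 16 matches the Hadamard matrices in the diff below):

import torch

# Before: the matrix is built on the host, so the BLAS work that
# derives the RHT matrix runs on CPU tensors and the result is
# copied to the GPU at the end.
m = torch.eye(16, dtype=torch.float32)
rht_before = m.to(dtype=torch.bfloat16).cuda()

# After: the matrix is allocated on the device from the start, so the
# derivation runs on the GPU and no host-to-device copy is needed.
m = torch.eye(16, dtype=torch.float32, device="cuda")
rht_after = m.to(dtype=torch.bfloat16)

Since these helpers are cached with functools.lru_cache, avoiding the host round trip also means the cached tensors are created on-device from the start; presumably this is what makes the helpers safe under CUDA graph capture, per the "Fix cuda graphs" commit above.
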
@@ -29,7 +29,7 @@ aten = torch.ops.aten


 def get_no_random_sign_vector() -> torch.Tensor:
     """Non-random sign vector for Hadamard transform."""
-    return torch.tensor([1], dtype=torch.float32)
+    return torch.tensor([1], dtype=torch.float32, device="cuda")


 def get_sign_from_vector(vector: torch.Tensor) -> int:
@@ -41,7 +41,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
     mask = 0
     for i, v in enumerate(vector):
         mask |= (v == -1) << i
-    return mask
+    return mask.item()


 def get_wgrad_sign_vector() -> torch.Tensor:
@@ -53,6 +53,7 @@ def get_wgrad_sign_vector() -> torch.Tensor:
     return torch.tensor(
         [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
         dtype=torch.float32,
+        device="cuda",
     )
@@ -81,6 +82,7 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
             [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
         ],
         dtype=torch.float32,
+        device="cuda",
     )
     * hadamard_scale
 )
@@ -94,9 +96,9 @@ def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
         signs = get_wgrad_sign_vector()
     else:
         signs = get_no_random_sign_vector()
-    sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32)
+    sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda")
     rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
-    return rht_matrix.to(dtype=torch.bfloat16).cuda()
+    return rht_matrix.to(dtype=torch.bfloat16)


 @functools.lru_cache(maxsize=None)
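
The mask fix is subtle enough to deserve a standalone illustration. A minimal sketch (only get_sign_from_vector comes from the diff above; the docstring and sample vector are illustrative additions):

import torch

def get_sign_from_vector(vector: torch.Tensor) -> int:
    """Pack the sign pattern of `vector` into a bitmask: bit i is set iff vector[i] == -1."""
    mask = 0
    for i, v in enumerate(vector):
        # `(v == -1)` is a 0-d tensor, so `|=` silently promotes
        # `mask` from a Python int to a tensor as well.
        mask |= (v == -1) << i
    # Without `.item()` the function would return that 0-d tensor,
    # violating the `-> int` annotation and breaking callers that
    # expect a plain Python integer.
    return mask.item()

signs = torch.tensor([1, 1, 1, -1], dtype=torch.float32)
print(get_sign_from_vector(signs))  # 8: only index 3 holds a -1
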