"docs/static/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "a4d8a4ea6b7869d7fd5b7d6bda057752204d603e"
Unverified Commit 63bf1609 authored by Lei Wang, committed by GitHub

[Bugfix] Fix fp8 dtype for some cases (#1246)

* [Enhancement] Add FP8 support and reproducibility in lighting indexer

* Introduced a manual seed in `test_fp8_lighting_indexer` to ensure reproducible performance.
* Added specializations for `cute::float_e4m3_t` and `cute::float_e5m2_t` in `gemm_mma.h` to extend FP8 support across multiple CUDA architectures (a sketch of the pattern appears below, before the diff).

* Fix typos in `fp8_lighting_indexer.py` and improve formatting in `gemm_mma.h`

* Corrected a typo in the comment for `test_fp8_lighting_indexer` to enhance clarity.
* Reformatted lines in `gemm_mma.h` for better readability by aligning template specializations across multiple CUDA architectures.

* test fix

* bug fix
parent f550a58d
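
The fp8 `DispatchInstruction` specializations mentioned in the commit message are not part of the excerpt below, which only shows the dispatch-site change in `gemm_mma.h`. As a rough, self-contained sketch of the pattern, assuming a primary-template shape matching the dispatch site and CUTLASS's SM89 fp8 MMA atom (both are assumptions, not code from this commit):

    #include <cute/tensor.hpp> // cute::float_e4m3_t, MMA_Atom, SM89 fp8 atoms

    // Primary template: left undefined so an operand-type combination with
    // no matching specialization fails at compile time.
    template <typename A_type, typename B_type, typename C_type,
              int num_warp_m, int num_warp_n, int N>
    struct DispatchInstruction;

    // Hypothetical fp8 (e4m3 x e4m3 -> f32) specialization. Whether the real
    // header selects SM89_16x8x32_F32E4M3E4M3F32_TN per architecture is an
    // assumption; the commit only states that specializations were added for
    // cute::float_e4m3_t and cute::float_e5m2_t.
    template <int num_warp_m, int num_warp_n, int N>
    struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, float,
                               num_warp_m, num_warp_n, N> {
      using MMA = cute::MMA_Atom<cute::SM89_16x8x32_F32E4M3E4M3F32_TN>;
    };

The dispatch-site change in the second hunk then keys this lookup on the raw dtypes so that such specializations are actually matched.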
fp8_lighting_indexer.py
@@ -258,6 +258,8 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
 def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
+    # initial random seed to make the performance reproducible
+    torch.manual_seed(0)
     q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
...
gemm_mma.h
@@ -273,8 +273,8 @@ public:
                              tfloat32_t, B_type_cute>::type;
   using C_type = C_type_raw;
-  using Instruction =
-      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
+  using Instruction = DispatchInstruction<A_type_raw, B_type_raw, C_type_raw,
+                                          num_warp_m, num_warp_n, N>;
   using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
                                        !trans_A, num_warp_m, lda>;
...
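
Why key the dispatch on the raw types? `A_type`/`B_type` are post-conversion operand types (the surviving context line shows `float` being remapped to `tfloat32_t` for TF32 MMA), and template specializations match on exact arguments, so a lookup keyed on a converted type can miss a specialization written for the original dtype. A minimal standalone illustration of that matching behavior, using stand-in types rather than the real header:

    #include <cstdio>

    struct fp8_e4m3 {}; // stand-in for cute::float_e4m3_t
    struct tf32 {};     // stand-in for a converted operand type

    // Fallback picked whenever no specialization matches the key exactly.
    template <typename A_type> struct Dispatch {
      static constexpr const char *name = "fallback instruction";
    };
    // Specialization keyed on the raw fp8 type.
    template <> struct Dispatch<fp8_e4m3> {
      static constexpr const char *name = "fp8 MMA instruction";
    };

    int main() {
      std::printf("%s\n", Dispatch<fp8_e4m3>::name); // fp8 MMA instruction
      std::printf("%s\n", Dispatch<tf32>::name);     // fallback instruction
    }

Dispatching on `A_type_raw`/`B_type_raw` therefore ensures the fp8 specializations are keyed on the dtype the kernel was actually instantiated with.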