Commit 113ee450 authored by zhanghj2's avatar zhanghj2
Browse files

fix k_scale 未定义

parent 702e8c22
......@@ -91,7 +91,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
# k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment