Unverified Commit a12061df authored by Ben Barsdell, committed by GitHub

Fix cuda graph mode in flashinfer attn backend (#10056)

parent 85ed8e0a
@@ -501,8 +501,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
@@ -580,8 +581,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 sm_scale=layer.scaling,
                 logits_soft_cap=layer.logit_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
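A minimal sketch of the pattern this patch relies on, using a hypothetical AttentionLayer class (not the actual sglang layer code): reading a 0-dim GPU tensor with .item(), or passing it where the kernel wrapper converts it to a Python number, forces a device-to-host copy and a stream synchronization, which is not allowed while a CUDA graph is being captured. Caching the value as a Python float once, before capture, avoids that sync on every forward call.

```python
import torch


class AttentionLayer(torch.nn.Module):
    """Hypothetical sketch of caching scale tensors as Python floats."""

    def __init__(self):
        super().__init__()
        # 0-dim scale tensors living on the GPU (e.g. loaded from a quantized checkpoint).
        self.k_scale = torch.tensor(1.0, device="cuda")
        self.v_scale = torch.tensor(1.0, device="cuda")
        # Host-side copies computed eagerly, outside any CUDA graph capture.
        # .item() performs the device-to-host copy here, exactly once.
        self.k_scale_float = self.k_scale.item()
        self.v_scale_float = self.v_scale.item()
```

With this layout, the attention call made during graph capture can receive plain Python floats (layer.k_scale_float / layer.v_scale_float) instead of GPU tensors that would trigger an implicit .item() synchronization inside the capture.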