Unverified commit 069e490b authored by elvischenv, committed by GitHub

feat: support trtllm_mha FP8 query attention kernel (#12307)

parent ab95d35f
@@ -529,6 +529,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
 )
+if self.data_type == torch.float8_e4m3fn:
+    q = q.to(torch.float8_e4m3fn)
 q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
 k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
 # shape conversion:
@@ -567,6 +569,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     window_left=layer.sliding_window_size,
     # TODO: add attention_sink operation or nvfp4 scale factor if needed
     sinks=attention_sink,
+    out_dtype=self.q_data_type,  # model_runner.dtype
 )
 return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -586,6 +589,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
 forward_batch.token_to_kv_pool.set_kv_buffer(
     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
 )
+if self.data_type == torch.float8_e4m3fn:
+    q = q.to(torch.float8_e4m3fn)
 q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
 # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
 k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -625,6 +631,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     window_left=layer.sliding_window_size,
     # TODO: add attention_sink operation or nvfp4 scale factor if needed
     sinks=attention_sink,
+    out_dtype=self.q_data_type,  # model_runner.dtype
 )
 return o.view(-1, layer.tp_q_head_num * layer.head_dim)
...
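
For readers skimming the diff: when the KV cache dtype (self.data_type) is FP8 (torch.float8_e4m3fn), the query is cast to the same FP8 dtype before the TRT-LLM MHA kernel call, and out_dtype=self.q_data_type pins the attention output to the model's runtime dtype. The sketch below illustrates only that dtype-gating pattern. run_trtllm_mha and forward_sketch are hypothetical stand-ins (the real backend dispatches to FlashInfer's TRT-LLM MHA kernels), and the tensor shapes are illustrative, not taken from the backend.

import torch


def run_trtllm_mha(q, k_cache, v_cache, *, out_dtype, **kwargs):
    # Hypothetical stand-in for the FlashInfer TRT-LLM MHA kernel call.
    # Only the dtype contract is emulated: FP8 inputs are accepted and the
    # output comes back in out_dtype (the model's runtime dtype).
    return torch.zeros(q.shape, dtype=out_dtype, device=q.device)


def forward_sketch(q, k_cache, v_cache, kv_cache_dtype, model_dtype,
                   tp_q_head_num, head_dim):
    # Mirror of the diff: cast the query to FP8 only when the KV cache is FP8,
    # so q, k_cache and v_cache all enter the kernel as float8_e4m3fn.
    if kv_cache_dtype == torch.float8_e4m3fn:
        q = q.to(torch.float8_e4m3fn)
    q = q.contiguous().view(-1, tp_q_head_num, head_dim)

    # out_dtype keeps the attention output in the model's runtime dtype
    # (e.g. bfloat16) even though the inputs are FP8.
    o = run_trtllm_mha(q, k_cache, v_cache, out_dtype=model_dtype)
    return o.view(-1, tp_q_head_num * head_dim)


if __name__ == "__main__":
    tp_q_head_num, head_dim, num_tokens = 8, 128, 4
    q = torch.randn(num_tokens, tp_q_head_num * head_dim, dtype=torch.bfloat16)
    # Toy FP8 KV-cache buffers (shapes are illustrative only).
    k_cache = torch.zeros(1, 64, tp_q_head_num, head_dim).to(torch.float8_e4m3fn)
    v_cache = torch.zeros(1, 64, tp_q_head_num, head_dim).to(torch.float8_e4m3fn)
    o = forward_sketch(q, k_cache, v_cache,
                       kv_cache_dtype=torch.float8_e4m3fn,
                       model_dtype=torch.bfloat16,
                       tp_q_head_num=tp_q_head_num, head_dim=head_dim)
    print(o.shape, o.dtype)  # torch.Size([4, 1024]) torch.bfloat16

Note that the diff casts the query with a plain q.to(torch.float8_e4m3fn), i.e. without an explicit per-tensor scale at this point; layer.k_scale and layer.v_scale are applied when the KV cache is written via set_kv_buffer.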