Commit 783bd930 authored by zhanggzh

add fallback code for K100 (FlashAttention not supported)

parent c0dbff40
@@ -171,7 +171,7 @@ def local_chat(
     "please change max_seq_len in ~/.ktransformers/config.yaml"
     #if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and (get_compute_capability() >= 8 or ("Z100" in get_device_name()) or ("Z100L" in get_device_name())):
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and (get_compute_capability() >= 8 or ("Z100" in get_device_name()) or ("Z100L" in get_device_name())) or ("K100" in get_device_name()):
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
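In the new line above, the appended ("K100" in get_device_name()) term sits outside the closing parenthesis, so it is OR-ed against the entire guard rather than joining the Z100/Z100L alternatives. A minimal sketch of the same check with the device names grouped under the compute-capability clause, and with a membership test for the architecture name (this grouping is an assumption about intent, not what the commit applies):

# Sketch only: assumes the K100 test is meant to be one more device alternative,
# alongside Z100/Z100L, inside the flashinfer-enabled guard.
use_flashinfer_path = (
    system != "Windows"
    and config.architectures[0] in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM")
    and flashinfer_enabled
    and (
        get_compute_capability() >= 8
        or "Z100" in get_device_name()   # also matches "Z100L" device names
        or "K100" in get_device_name()
    )
)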
@@ -592,8 +592,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         #if os.name == 'nt' or get_compute_capability()<8:
             #print("for Windows or GPU before ampere, use forward_windows")
-        if os.name == 'nt' or get_compute_capability()<8 or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()):
-            print("for Windows or GPU before ampere or Z100/Z100L, use forward_windows")
+        if os.name == 'nt' or get_compute_capability()<8 or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name()):
+            print("for Windows or GPU before ampere or Z100/Z100L or K100, use forward_windows")
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -652,8 +652,8 @@ class KDeepseekV2Model(BaseInjectedModule):
         else:
             #if os.name == 'nt' or get_compute_capability()<8:
             # print("for Windows or GPU before ampere, use forward_windows")
-            if os.name == 'nt' or get_compute_capability()<8 or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()):
-                print("for Windows or GPU before ampere or Z100/Z100L, use forward_windows")
+            if os.name == 'nt' or get_compute_capability()<8 or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name()):
+                print("for Windows or GPU before ampere or Z100/Z100L or K100, use forward_windows")
             # only use mask in forward windows or can't flash attn
             causal_mask = self._update_causal_mask(
                 attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
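The two hunks above add the same K100 test to the gate that routes DeepSeek attention to the forward_windows (non-FlashAttention) path. A small sketch of how that repeated gate could be collected into one helper; the helper name is hypothetical and not part of the commit, and it assumes get_compute_capability() and get_device_name() are imported the same way as in the files being patched:

import os

# Device name tags that should take the non-FlashAttention path.
# "Z100" also matches "Z100L" device names.
_NO_FLASH_ATTN_TAGS = ("Z100", "K100")

def needs_forward_windows() -> bool:
    """True when the FlashAttention path cannot be used on this platform/device."""
    if os.name == 'nt':                   # Windows builds: no FlashAttention kernels
        return True
    if get_compute_capability() < 8:      # NVIDIA GPUs before Ampere
        return True
    device = get_device_name()
    return any(tag in device for tag in _NO_FLASH_ATTN_TAGS)

Both call sites could then reduce to a single "if needs_forward_windows():" check.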