Unverified commit 898285c9 authored by Kyujin Cho, committed by GitHub

fix: CUDA error when inferencing with Falcon-40B base model (#992)

parent a62de9ec
@@ -114,8 +114,9 @@ class ModelConfig:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
         new_decoder_arch_falcon = (
-            self.hf_config.model_type == "falcon"
+            self.hf_config.model_type in falcon_model_types
             and getattr(self.hf_config, "new_decoder_architecture", False))
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):
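For context, here is a minimal sketch (not part of the commit) of the corrected check in isolation. Legacy Falcon checkpoints on the Hugging Face Hub report model_type as "RefinedWeb" or "RefinedWebModel" rather than "falcon", so the old equality check skipped them, the multi_query fallback ran instead, and the KV head count came out wrong. It assumes `transformers` is installed; the repo id below is only an example.

```python
from transformers import AutoConfig

# All model_type strings that Falcon checkpoints are known to report.
FALCON_MODEL_TYPES = ["falcon", "RefinedWeb", "RefinedWebModel"]


def is_new_decoder_arch_falcon(hf_config) -> bool:
    # Membership test instead of `== "falcon"`, so legacy checkpoints that
    # report "RefinedWeb" or "RefinedWebModel" are also treated as Falcon.
    return (hf_config.model_type in FALCON_MODEL_TYPES
            and getattr(hf_config, "new_decoder_architecture", False))


if __name__ == "__main__":
    config = AutoConfig.from_pretrained("tiiuae/falcon-40b",
                                        trust_remote_code=True)
    print(config.model_type, is_new_decoder_arch_falcon(config))
```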