Commit 422af727 authored by zhuwenwen's avatar zhuwenwen
Browse files

restore qwen-vl to default layout

parent 6d4b207a
...@@ -22,10 +22,14 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,10 +22,14 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visual = getattr(model_config.hf_config, "visual", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' if architectures == ['QWenLMHeadModel'] and visual != []:
os.environ['LLAMA_NN'] = '0'
else:
os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment