Unverified Commit bb76538b authored by Shashwat Srijan's avatar Shashwat Srijan Committed by GitHub
Browse files

[Hardware][Neuron] Simplify model load for transformers-neuronx library (#9380)

parent d615b5c9
...@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple ...@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import transformers
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
...@@ -108,39 +107,11 @@ class NeuronCasualLM(nn.Module): ...@@ -108,39 +107,11 @@ class NeuronCasualLM(nn.Module):
neuronx_module = importlib.import_module(neuronx_module_path) neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
split_model_dir = f"{model_name_or_path}-split" self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
if _is_pretrained_neuron_checkpoint(model_name_or_path):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls_name)
from transformers_neuronx.module import save_pretrained_split
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = neuronx_model_cls.from_pretrained(split_model_dir,
**kwargs) **kwargs)
self.model.to_neuron() self.model.to_neuron()
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Return True if ``model_name_or_path`` looks like a split Neuron checkpoint.

    Two on-disk layouts are recognised:

    * old format: the path contains a *directory* named ``pytorch_model.bin``;
    * new format: the path contains both ``config.json`` and
      ``generation_config.json`` as regular files, plus at least one entry
      whose name ends with ``.safetensors``.

    Args:
        model_name_or_path: Filesystem path to inspect.

    Returns:
        True when either layout is detected, False otherwise (including when
        the path does not exist).
    """
    # Old format: "pytorch_model.bin" is tested as a directory, not a file.
    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
        return True

    # New format: both metadata files must be present. Checking them first
    # also guarantees the listdir() below is only reached for a path that
    # actually contains files.
    required_files = ("config.json", "generation_config.json")
    if not all(
            os.path.isfile(os.path.join(model_name_or_path, name))
            for name in required_files):
        return False

    # ... together with at least one safetensors entry.
    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
def _get_model_architecture(config: PretrainedConfig) -> str: def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment