Commit 0bc22e1d authored by wanglch

Initial commit

# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
# from vary.train.train import train
from vary.train.train_lora import train
if __name__ == "__main__":
    train()
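
The body of `replace_llama_attn_with_flash_attn` lives in `vary/utils/llama_flash_attn_monkey_patch.py` and is not shown here. As a rough sketch of how a patch like this usually works (this is an assumption, not the repository's implementation), it rebinds the attention `forward` on the transformers `LlamaAttention` class, which is why it has to run before the training code builds the model:

# Minimal sketch of a LLaMA attention monkey patch -- an assumption, NOT the
# actual vary.utils.llama_flash_attn_monkey_patch implementation.
from transformers.models.llama.modeling_llama import LlamaAttention

_original_forward = LlamaAttention.forward

def _flash_attn_forward(self, *args, **kwargs):
    # A real patch would dispatch to a FlashAttention kernel here; this
    # placeholder simply delegates to the original implementation.
    return _original_forward(self, *args, **kwargs)

def replace_llama_attn_with_flash_attn():
    # Rebinding on the class means every LlamaAttention instance created
    # afterwards (i.e. the whole model) uses the patched forward.
    LlamaAttention.forward = _flash_attn_forward
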
import os
import torch
import torch.nn as nn
from transformers import Trainer
from typing import Dict, Optional, Sequence
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a model from potential containers (as used in distributed training).

    Args:
        model (`torch.nn.Module`): The model to unwrap.
    """
    # since there could be multiple levels of wrapping, unwrap recursively
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    else:
        return model
class varyTrainer(Trainer):

    def _safe_save(self, output_dir: str):
        """Collects the state dict and dumps it to disk."""
        if self.deepspeed:
            torch.cuda.synchronize()
            self.save_model(output_dir)
            return

        state_dict = self.model.state_dict()
        if self.args.should_save:
            cpu_state_dict = {
                key: value.cpu()
                for key, value in state_dict.items()
            }
            del state_dict
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # Save the model
            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            # Keep only the multimodal projector and embedding weights.
            weight_to_save = {}
            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
            for k, v in _state_dict.items():
                if any(key_match in k for key_match in keys_to_match):
                    weight_to_save[k] = v

            current_folder = output_dir.split('/')[-1]
            parent_folder = os.path.dirname(output_dir)
            if current_folder.startswith('checkpoint-'):
                # Intermediate checkpoints go into a shared mm_projector folder.
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))

        super(varyTrainer, self)._save(output_dir, state_dict)
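
When `tune_mm_mlp_adapter` is enabled, `_save` writes only the projector and embedding weights to `mm_projector.bin` (or, for intermediate checkpoints, to a `mm_projector/` folder next to the checkpoint directories) before delegating to the stock `Trainer._save`. A small usage sketch follows; the module and file names are illustrative, not part of the repository:

# Usage sketch (illustrative names).  unwrap_model is the helper defined
# above; it strips DataParallel/DistributedDataParallel wrappers.
import torch.nn as nn

base = nn.Linear(4, 4)
wrapped = nn.DataParallel(base)        # wrapper exposes the real module as .module
assert unwrap_model(wrapped) is base   # recursion stops at the bare module

# The partial state dict saved by varyTrainer._save can be loaded back into a
# model whose parameter names match (e.g. "mm_projector.*") with strict=False,
# so keys that were not saved are simply left untouched:
# weights = torch.load("mm_projector.bin", map_location="cpu")
# model.load_state_dict(weights, strict=False)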