Commit c873301f authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
This diff is collapsed.
import os
import torch
import torch.nn as nn
from transformers import Trainer
from typing import Dict, Optional, Sequence
def unwrap_model(model: nn.Module) -> nn.Module:
    """Strip distributed-training wrappers (e.g. DDP) from *model*.

    Wrapper containers expose the real network as a ``.module``
    attribute; peel those layers off one at a time until none remain.

    Args:
        model (`torch.nn.Module`): The (possibly wrapped) model.

    Returns:
        The innermost unwrapped module.
    """
    while hasattr(model, "module"):
        model = model.module
    return model
class varyTrainer(Trainer):
    """Trainer that can checkpoint either the full model or only the
    multimodal-adapter weights (projector / embedding layers)."""

    def _safe_save(self, output_dir: str):
        """Collects the state dict and dump to disk.

        Under DeepSpeed the parameter/optimizer state is sharded, so the
        save must go through ``save_model``; otherwise the weights are
        moved to CPU first so the checkpoint write does not pin GPU memory.
        """
        if self.deepspeed:
            # Make sure all in-flight CUDA work is finished before saving.
            torch.cuda.synchronize()
            self.save_model(output_dir)
            return

        state_dict = self.model.state_dict()
        if self.args.should_save:
            cpu_state_dict = {
                key: value.cpu()
                for key, value in state_dict.items()
            }
            # Drop the GPU-resident copies before writing to disk.
            del state_dict
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write a checkpoint.

        When ``args.tune_mm_mlp_adapter`` is set, only the adapter weights
        (keys containing ``mm_projector`` / ``embed_tokens`` / ``embed_in``)
        are saved; otherwise the stock HF ``Trainer._save`` path runs.
        NOTE(review): indentation was lost in the scraped source — the
        ``super()._save`` call is placed in the ``else`` branch to match the
        upstream LLaVA/Vary trainer; confirm against the original repo.
        """
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            # Keep only the multimodal-adapter weights.
            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
            weight_to_save = {
                k: v for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            current_folder = output_dir.split('/')[-1]
            parent_folder = os.path.dirname(output_dir)
            if current_folder.startswith('checkpoint-'):
                # Collect per-checkpoint adapter files under one shared folder.
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        else:
            super(varyTrainer, self)._save(output_dir, state_dict)
This diff is collapsed.
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import transformers
@dataclass
class ModelArguments:
    """CLI options describing the base model and vision tower."""

    model_name_or_path: Optional[str] = "facebook/opt-125m"
    use_cache: bool = False
    vision_tower: Optional[str] = "~/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff/"
    freeze_vision_tower: bool = False
    freeze_lm_model: bool = False
    # mlp &/ vision tower
    pretrained_stage1_model: Optional[str] = None
    # default to the last layer
    vision_select_layer: Optional[int] = -1
    use_im_start_end: bool = False
@dataclass
class DataArguments:
    """CLI options describing the training data and image preprocessing."""

    # Which training-data combination to use; default None was annotated
    # as plain `str` — corrected to Optional[str].
    datasets: Optional[str] = field(default=None, metadata={"help": "combinations of the training data."})
    # presumably controls placing the image tokens before the conversation
    # text — TODO confirm against the dataset/collator code
    sep_image_conv_front: bool = False
    # number of visual tokens an image expands to
    image_token_len: int = 256
    image_aspect_ratio: str = 'square'
    # conversation template name; alternatives kept from the original source
    conversation_version: str = 'mpt'
    # conversation_version: str = 'v0'
    # conversation_version: str = 'v1'
    # conversation_version: str = 'opt'
    box_limit: int = 0
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF TrainingArguments extended with project-specific flags."""

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    # keep all dataset columns — presumably the custom collator reads
    # fields the model signature does not declare; TODO confirm
    remove_unused_columns: bool = field(default=False)
    force_fsdp: bool = field(default=False)
    interleave: bool = field(default=False)
    with_box: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # LoRA fine-tuning options (rank, scaling, dropout, pretrained weights, bias mode)
    lora_enable: bool = False
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
{
"bf16": {
"enabled": true
},
"train_micro_batch_size_per_gpu": "auto",
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment