suily / VTimeLLM_pytorch · Commits

Commit fef630ee, authored Nov 20, 2024 by suily

    init

Pipeline #1942 failed with stages in 0 seconds · Changes: 65 · Pipelines: 1

Showing 5 changed files with 710 additions and 0 deletions (+710, -0):

vtimellm/train/llama_flash_attn_monkey_patch.py (+124, -0)
vtimellm/train/train.py (+385, -0)
vtimellm/train/train_mem.py (+20, -0)
vtimellm/train/vtimellm_trainer.py (+55, -0)
vtimellm/utils.py (+126, -0)
vtimellm/train/llama_flash_attn_monkey_patch.py (new file, mode 100755)

```python
from typing import List, Optional, Tuple
import logging
import torch
from torch import nn
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]

    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        logging.warning(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
```
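A hedged sanity check, not part of the commit: after `replace_llama_attn_with_flash_attn()` runs, the two monkey-patched attributes on the Hugging Face LLaMA classes should point at the functions defined above. The sketch assumes a CUDA device (the capability query fails on CPU-only machines), an installed `flash_attn`, and a transformers version whose LLaMA attention still follows this interface.

```python
# Hypothetical verification sketch, assuming this module is importable as
# llama_flash_attn_monkey_patch and a CUDA device is available.
import transformers
from llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
    forward,
    _prepare_decoder_attention_mask,
)

replace_llama_attn_with_flash_attn()

llama = transformers.models.llama.modeling_llama
assert llama.LlamaAttention.forward is forward
assert llama.LlamaModel._prepare_decoder_attention_mask is _prepare_decoder_attention_mask
print("LlamaAttention.forward now routes through FlashAttention")
```

Because the patch rebinds class attributes, it takes effect for any LLaMA model instantiated afterwards, which is why `train_mem.py` (below) applies it before importing the training entry point.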
vtimellm/train/train.py (new file, mode 100755)

```python
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")

from dataclasses import dataclass, field
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch
import transformers

import sys
sys.path.append(root_dir)

from vtimellm import conversation as conversation_lib
from vtimellm.train.vtimellm_trainer import VTimeLLMTrainer
from vtimellm.train.dataset import make_supervised_data_module, DataArguments
from vtimellm.model import VTimeLLMLlamaForCausalLM, VTimeLLMChatGLMForCausalLM
from vtimellm.model.builder import load_lora
from vtimellm.mm_utils import print_trainable_parameters

local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.5")
    stage2_path: Optional[str] = field(default=None)
    version: Optional[str] = field(default="v0")
    tune_mm_mlp_adapter: bool = field(default=False)
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    training_stage: int = field(default=2)
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias:
            if bias_name in lora_bias_names:
                to_return[bias_name] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""

    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        # Only save Adapter
        keys_to_match = ['mm_projector']
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in'])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        current_folder = output_dir.split('/')[-1]
        parent_folder = os.path.dirname(output_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
            )
        ))

    if 'chatglm' in model_args.model_name_or_path:
        model = VTimeLLMChatGLMForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            empty_init=False,
            device='cuda'
        )
    else:
        model = VTimeLLMLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args
        )

    model.config.use_cache = False

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)

        # print_trainable_parameters(model)
        if training_args.training_stage == 3:
            model.get_model().initialize_vision_modules(model_args)
            model = load_lora(model, model_args.stage2_path)
            rank0_print('Merging LoRA weights...')
            model = model.merge_and_unload()
            # print_trainable_parameters(model)

            rank0_print("Adding LoRA adapters...")
            model = get_peft_model(model, lora_config)
        else:
            rank0_print("Adding LoRA adapters...")
            model = get_peft_model(model, lora_config)
        # print_trainable_parameters(model)

    if 'chatglm' in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.pad_token = tokenizer.unk_token

    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter

    if training_args.training_stage != 3:
        model.get_model().initialize_vision_modules(model_args=model_args)

    if model_args.tune_mm_mlp_adapter:
        model.requires_grad_(False)
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True

    if training_args.freeze_mm_mlp_adapter:
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = False

    if training_args.bits in [4, 8]:
        model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = VTimeLLMTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
```
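A hedged sketch, not from the repository, of how the `HfArgumentParser` call in `train()` turns command-line flags into the dataclasses defined above. It uses only fields declared in this file plus standard `transformers.TrainingArguments` fields; `DataArguments` comes from `vtimellm.train.dataset`, which is not part of this commit, so it is left out here. The checkpoint name and output path are placeholders.

```python
# Minimal illustration of the argument plumbing in train(), assuming it runs in a
# context where ModelArguments and TrainingArguments from this file are in scope.
import transformers

parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--model_name_or_path", "lmsys/vicuna-7b-v1.5",   # the default, shown explicitly
    "--training_stage", "2",
    "--lora_enable", "True",
    "--model_max_length", "2048",
    "--output_dir", "./checkpoints/vtimellm-stage2",  # placeholder path
])

assert training_args.training_stage == 2 and training_args.lora_enable
print(model_args.model_name_or_path, training_args.model_max_length)
```

In the real entry point the same parser also consumes the dataset flags, and the resulting `data_args` is handed to `make_supervised_data_module` before the `VTimeLLMTrainer` is built.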
vtimellm/train/train_mem.py (new file, mode 100755)

```python
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:

# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
import os
root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
print(root_dir)

import sys
sys.path.append(root_dir)

from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()

from train import train

if __name__ == "__main__":
    train()
```
vtimellm/train/vtimellm_trainer.py (new file, mode 100755)

```python
import os
import torch

from transformers import Trainer
from typing import Optional


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, 'no ignore status')
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return


class VTimeLLMTrainer(Trainer):

    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ['mm_projector']
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(['embed_tokens', 'embed_in'])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        else:
            super(VTimeLLMTrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            pass
        else:
            super(VTimeLLMTrainer, self)._save(output_dir, state_dict)
```
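A toy, hedged illustration, not from the repository, of what `get_mm_adapter_state_maybe_zero_3` keeps: only parameters whose names contain one of `keys_to_match`, which is how the trainer saves just the `mm_projector` weights at each checkpoint. The toy module below is hypothetical; DeepSpeed still has to be installed because `maybe_zero_3` imports it unconditionally, but the ZeRO-3 gathering path is skipped since plain tensors carry no `ds_id` attribute.

```python
# Hypothetical toy module; only the mm_projector entries should survive the filter.
import torch.nn as nn

toy = nn.ModuleDict({
    "mm_projector": nn.Linear(4, 4),  # stands in for the multimodal projector
    "lm_head": nn.Linear(4, 4),       # stands in for a backbone head
})

kept = get_mm_adapter_state_maybe_zero_3(toy.named_parameters(), ["mm_projector"])
print(sorted(kept))  # ['mm_projector.bias', 'mm_projector.weight']
```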
vtimellm/utils.py (new file, mode 100644)

```python
import datetime
import logging
import logging.handlers
import os
import sys

import requests

from vtimellm.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True)
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    data = "{" + '"input": ' + f'"{text}"' + "}"
    data = data.encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException as e:
        flagged = False
    except KeyError as e:
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
```