Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4c03347f
Unverified
Commit
4c03347f
authored
Feb 08, 2024
by
Frank Lee
Committed by
GitHub
Feb 08, 2024
Browse files
Merge pull request #5377 from hpcaitech/example/llama-npu
[llama] support npu for Colossal-LLaMA-2
parents
c53ddda8
084c9124
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
497 additions
and
733 deletions
+497
-733
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
...ations/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
+10
-58
applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
...al-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
+321
-181
applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
...s/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
+1
-1
applications/Colossal-LLaMA-2/train.example.sh
applications/Colossal-LLaMA-2/train.example.sh
+1
-0
applications/Colossal-LLaMA-2/train.py
applications/Colossal-LLaMA-2/train.py
+105
-68
applications/Colossal-LLaMA-2/train_sft.example.sh
applications/Colossal-LLaMA-2/train_sft.example.sh
+2
-1
applications/Colossal-LLaMA-2/train_sft.py
applications/Colossal-LLaMA-2/train_sft.py
+0
-403
applications/ColossalEval/colossal_eval/models/chatglm.py
applications/ColossalEval/colossal_eval/models/chatglm.py
+5
-3
applications/ColossalEval/colossal_eval/models/huggingface.py
...ications/ColossalEval/colossal_eval/models/huggingface.py
+9
-8
applications/ColossalEval/examples/dataset_evaluation/inference.py
...ons/ColossalEval/examples/dataset_evaluation/inference.py
+4
-2
colossalai/booster/plugin/dp_plugin_base.py
colossalai/booster/plugin/dp_plugin_base.py
+12
-2
colossalai/booster/plugin/gemini_plugin.py
colossalai/booster/plugin/gemini_plugin.py
+12
-2
colossalai/booster/plugin/hybrid_parallel_plugin.py
colossalai/booster/plugin/hybrid_parallel_plugin.py
+12
-2
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+3
-2
No files found.
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
View file @
4c03347f
#!/usr/bin/env python3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
numpy
as
np
import
os
import
os
import
random
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Un
ion
,
Sequence
,
Optional
,
Iterator
,
Callable
from
typing
import
Dict
,
Iterator
,
List
,
Opt
ion
al
,
Sequence
,
Union
import
torch
import
torch
from
datasets
import
dataset_dict
,
load_from_disk
import
torch.nn.functional
as
F
from
datasets
import
Dataset
as
HFDataset
from
datasets
import
Dataset
as
HFDataset
from
torch.distributed
import
ProcessGroup
from
datasets
import
dataset_dict
,
load_from_disk
from
torch.distributed.distributed_c10d
import
_get_default_group
from
torch.utils.data
import
ConcatDataset
,
Dataset
,
DistributedSampler
from
torch.utils.data
import
ConcatDataset
,
Dataset
,
DataLoader
,
DistributedSampler
from
transformers.tokenization_utils
import
PreTrainedTokenizer
from
transformers.tokenization_utils
import
PreTrainedTokenizer
import
torch.nn.functional
as
F
DatasetType
=
Union
[
Dataset
,
ConcatDataset
,
dataset_dict
.
Dataset
]
DatasetType
=
Union
[
Dataset
,
ConcatDataset
,
dataset_dict
.
Dataset
]
PathType
=
Union
[
str
,
os
.
PathLike
]
PathType
=
Union
[
str
,
os
.
PathLike
]
...
@@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
...
@@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
tokenizer
:
PreTrainedTokenizer
tokenizer
:
PreTrainedTokenizer
max_length
:
int
=
4096
max_length
:
int
=
4096
ignore_index
:
int
=
-
100
ignore_index
:
int
=
-
100
padding
:
str
=
"max_length"
def
__call__
(
self
,
instances
:
Sequence
[
Dict
[
str
,
List
[
int
]]])
->
Dict
[
str
,
torch
.
Tensor
]:
def
__call__
(
self
,
instances
:
Sequence
[
Dict
[
str
,
List
[
int
]]])
->
Dict
[
str
,
torch
.
Tensor
]:
"""
"""
...
@@ -106,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
...
@@ -106,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
batch_first
=
True
,
batch_first
=
True
,
padding_value
=
self
.
ignore_index
,
padding_value
=
self
.
ignore_index
,
)
# (bsz, max_len)
)
# (bsz, max_len)
# pad to max
if
self
.
padding
==
"max_length"
:
to_pad
=
self
.
max_length
-
input_ids
.
size
(
1
)
# pad to max
input_ids
=
F
.
pad
(
input_ids
,
(
0
,
to_pad
),
value
=
self
.
tokenizer
.
pad_token_id
)
to_pad
=
self
.
max_length
-
input_ids
.
size
(
1
)
labels
=
F
.
pad
(
labels
,
(
0
,
to_pad
),
value
=
self
.
ignore_index
)
input_ids
=
F
.
pad
(
input_ids
,
(
0
,
to_pad
),
value
=
self
.
tokenizer
.
pad_token_id
)
labels
=
F
.
pad
(
labels
,
(
0
,
to_pad
),
value
=
self
.
ignore_index
)
elif
self
.
tokenizer
.
padding_side
==
"left"
:
elif
self
.
tokenizer
.
padding_side
==
"left"
:
reversed_input_ids
=
[
seq
.
flip
(
dims
=
(
0
,))
for
seq
in
batch_input_ids
]
reversed_input_ids
=
[
seq
.
flip
(
dims
=
(
0
,))
for
seq
in
batch_input_ids
]
reversed_input_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
reversed_input_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
...
@@ -171,49 +169,3 @@ class StatefulDistributedSampler(DistributedSampler):
...
@@ -171,49 +169,3 @@ class StatefulDistributedSampler(DistributedSampler):
def
set_start_index
(
self
,
start_index
:
int
)
->
None
:
def
set_start_index
(
self
,
start_index
:
int
)
->
None
:
self
.
start_index
=
start_index
self
.
start_index
=
start_index
def
setup_distributed_dataloader
(
dataset
:
DatasetType
,
batch_size
:
int
=
1
,
shuffle
:
bool
=
False
,
seed
:
int
=
1024
,
drop_last
:
bool
=
False
,
pin_memory
:
bool
=
False
,
num_workers
:
int
=
0
,
collate_fn
:
Callable
[[
Sequence
[
Dict
[
str
,
Union
[
str
,
List
[
int
]]]]],
Dict
[
str
,
torch
.
Tensor
]]
=
None
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
**
kwargs
,
)
->
DataLoader
:
"""
Setup dataloader for distributed training.
"""
_kwargs
=
kwargs
.
copy
()
process_group
=
process_group
or
_get_default_group
()
sampler
=
StatefulDistributedSampler
(
dataset
=
dataset
,
num_replicas
=
process_group
.
size
(),
rank
=
process_group
.
rank
(),
shuffle
=
shuffle
,
seed
=
seed
,
drop_last
=
drop_last
,
)
# Deterministic dataloader
def
seed_worker
(
worker_id
:
int
)
->
None
:
worker_seed
=
seed
np
.
random
.
seed
(
worker_seed
)
torch
.
manual_seed
(
worker_seed
)
random
.
seed
(
worker_seed
)
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
batch_size
,
sampler
=
sampler
,
num_workers
=
num_workers
,
collate_fn
=
collate_fn
,
pin_memory
=
pin_memory
,
drop_last
=
drop_last
,
worker_init_fn
=
seed_worker
,
**
_kwargs
,
)
applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
View file @
4c03347f
This diff is collapsed.
Click to expand it.
applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
View file @
4c03347f
...
@@ -17,7 +17,7 @@ import torch
...
@@ -17,7 +17,7 @@ import torch
def
unwrap
(
model
):
def
unwrap
(
model
):
if
hasattr
(
model
,
"module"
):
if
hasattr
(
model
,
"module"
):
return
unwrap_model
(
model
.
module
)
return
model
.
unwrap
(
)
else
:
else
:
return
model
return
model
...
...
applications/Colossal-LLaMA-2/train.example.sh
View file @
4c03347f
...
@@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
...
@@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
--warmup_steps
100
\
--warmup_steps
100
\
--use_grad_checkpoint
\
--use_grad_checkpoint
\
--use_flash_attn
\
--use_flash_attn
\
--pad_token
"unk"
applications/Colossal-LLaMA-2/train.py
View file @
4c03347f
#!/usr/bin/env python3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
"""
"""
Continual Pre-training
of
LLaMA-2 developed by Colossal-AI Team
Continual Pre-training
/Supervised fine-tuning of Colossal-
LLaMA-2 developed by Colossal-AI Team
"""
"""
import
argparse
import
argparse
...
@@ -16,22 +16,24 @@ from colossal_llama2.dataset.loader import (
...
@@ -16,22 +16,24 @@ from colossal_llama2.dataset.loader import (
DataCollatorForSupervisedDataset
,
DataCollatorForSupervisedDataset
,
StatefulDistributedSampler
,
StatefulDistributedSampler
,
load_tokenized_dataset
,
load_tokenized_dataset
,
setup_distributed_dataloader
,
)
)
from
colossal_llama2.utils.ckpt_io
import
load_checkpoint
,
save_checkpoint
from
colossal_llama2.utils.ckpt_io
import
load_checkpoint
,
save_checkpoint
from
colossal_llama2.utils.flash_attention_patch
import
replace_with_flash_attention
from
colossal_llama2.utils.flash_attention_patch
import
replace_with_flash_attention
from
colossal_llama2.utils.froze
import
freeze_non_embeds_parameters
from
colossal_llama2.utils.froze
import
freeze_non_embeds_parameters
from
colossal_llama2.utils.neftune_patch
import
activate_neftune
,
deactivate_neftune
from
torch.utils.tensorboard
import
SummaryWriter
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
LlamaConfig
,
LlamaForCausalLM
,
LlamaTokenizer
from
transformers
import
LlamaForCausalLM
,
LlamaTokenizer
import
colossalai
import
colossalai
from
colossalai.accelerator
import
get_accelerator
from
colossalai.booster
import
Booster
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
GeminiPlugin
,
HybridParallelPlugin
,
LowLevelZeroPlugin
from
colossalai.booster.plugin
import
GeminiPlugin
,
HybridParallelPlugin
,
LowLevelZeroPlugin
from
colossalai.cluster
import
DistCoordinator
from
colossalai.cluster
import
DistCoordinator
from
colossalai.lazy
import
LazyInitContext
from
colossalai.lazy
import
LazyInitContext
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.utils
import
get_current_device
def
get_model_numel
(
model
:
torch
.
nn
.
Module
)
->
int
:
def
get_model_numel
(
model
:
torch
.
nn
.
Module
)
->
int
:
...
@@ -83,6 +85,7 @@ def main() -> None:
...
@@ -83,6 +85,7 @@ def main() -> None:
parser
.
add_argument
(
"--tensorboard_dir"
,
type
=
str
,
default
=
"logs_dir"
,
help
=
"Tensorboard directory"
)
parser
.
add_argument
(
"--tensorboard_dir"
,
type
=
str
,
default
=
"logs_dir"
,
help
=
"Tensorboard directory"
)
parser
.
add_argument
(
"--config_file"
,
type
=
str
,
default
=
"config_file"
,
help
=
"Config file"
)
parser
.
add_argument
(
"--config_file"
,
type
=
str
,
default
=
"config_file"
,
help
=
"Config file"
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
1
,
help
=
"Number of training epochs"
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
1
,
help
=
"Number of training epochs"
)
parser
.
add_argument
(
"--accumulation_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of accumulation steps"
)
parser
.
add_argument
(
"--micro_batch_size"
,
type
=
int
,
default
=
2
,
help
=
"Batch size of each process"
)
parser
.
add_argument
(
"--micro_batch_size"
,
type
=
int
,
default
=
2
,
help
=
"Batch size of each process"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
3e-4
,
help
=
"Learning rate"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
3e-4
,
help
=
"Learning rate"
)
parser
.
add_argument
(
"--max_length"
,
type
=
int
,
default
=
4096
,
help
=
"Model max length"
)
parser
.
add_argument
(
"--max_length"
,
type
=
int
,
default
=
4096
,
help
=
"Model max length"
)
...
@@ -108,6 +111,12 @@ def main() -> None:
...
@@ -108,6 +111,12 @@ def main() -> None:
default
=
False
,
default
=
False
,
help
=
"Use flash-attention"
,
help
=
"Use flash-attention"
,
)
)
parser
.
add_argument
(
"--use_neft"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use NEFTune"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--freeze_non_embeds_params"
,
"--freeze_non_embeds_params"
,
action
=
"store_true"
,
action
=
"store_true"
,
...
@@ -116,6 +125,8 @@ def main() -> None:
...
@@ -116,6 +125,8 @@ def main() -> None:
)
)
parser
.
add_argument
(
"--tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--zero"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--zero"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--pad_token"
,
choices
=
[
"eos"
,
"unk"
],
default
=
"eos"
)
parser
.
add_argument
(
"--padding_mode"
,
choices
=
[
"max_length"
,
"longest"
],
default
=
"max_length"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
with
open
(
args
.
config_file
,
"w"
)
as
f
:
with
open
(
args
.
config_file
,
"w"
)
as
f
:
...
@@ -125,6 +136,7 @@ def main() -> None:
...
@@ -125,6 +136,7 @@ def main() -> None:
# Initialize Distributed Training
# Initialize Distributed Training
# ==============================
# ==============================
colossalai
.
launch_from_torch
({})
colossalai
.
launch_from_torch
({})
accelerator
=
get_accelerator
()
coordinator
=
DistCoordinator
()
coordinator
=
DistCoordinator
()
# ==============================
# ==============================
...
@@ -182,7 +194,10 @@ def main() -> None:
...
@@ -182,7 +194,10 @@ def main() -> None:
# Initialize Tokenizer, Dataset, Collator and Dataloader
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
# ======================================================
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
args
.
pretrained
)
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
args
.
pretrained
)
tokenizer
.
pad_token
=
tokenizer
.
unk_token
if
args
.
pad_token
==
"eos"
:
tokenizer
.
pad_token
=
tokenizer
.
eos_token
elif
args
.
pad_token
==
"unk"
:
tokenizer
.
pad_token
=
tokenizer
.
unk_token
tokenizer
.
add_bos_token
=
False
tokenizer
.
add_bos_token
=
False
tokenizer
.
add_eos_token
=
False
tokenizer
.
add_eos_token
=
False
...
@@ -193,38 +208,36 @@ def main() -> None:
...
@@ -193,38 +208,36 @@ def main() -> None:
coordinator
.
print_on_master
(
f
"Load dataset:
{
args
.
dataset
}
"
)
coordinator
.
print_on_master
(
f
"Load dataset:
{
args
.
dataset
}
"
)
dataset
=
load_tokenized_dataset
(
dataset_paths
=
args
.
dataset
,
mode
=
"train"
)
dataset
=
load_tokenized_dataset
(
dataset_paths
=
args
.
dataset
,
mode
=
"train"
)
data_collator
=
DataCollatorForSupervisedDataset
(
tokenizer
=
tokenizer
,
max_length
=
args
.
max_length
)
data_collator
=
DataCollatorForSupervisedDataset
(
dataloader
=
setup_distributed_dataloader
(
tokenizer
=
tokenizer
,
max_length
=
args
.
max_length
,
padding
=
args
.
padding_mode
)
dataloader
=
plugin
.
prepare_dataloader
(
dataset
=
dataset
,
dataset
=
dataset
,
batch_size
=
args
.
micro_batch_size
,
batch_size
=
args
.
micro_batch_size
,
shuffle
=
True
,
shuffle
=
True
,
drop_last
=
True
,
drop_last
=
True
,
collate_fn
=
data_collator
,
collate_fn
=
data_collator
,
distributed_sampler_cls
=
StatefulDistributedSampler
,
)
)
coordinator
.
print_on_master
(
coordinator
.
print_on_master
(
f
"Max
CUDA
memory after data loader:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
f
"Max
device
memory after data loader:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
)
# ======================================================
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# ======================================================
init_ctx
=
(
# colossalai has changed api for get_current_device in 0.3.4 version or newer
LazyInitContext
(
default_device
=
get_current_device
())
try
:
if
isinstance
(
plugin
,
(
GeminiPlugin
,
HybridParallelPlugin
))
from
colossalai.accelerator
import
get_accelerator
else
nullcontext
()
)
current_device
=
get_accelerator
().
get_current_device
()
except
:
from
colossalai.utils
import
get_current_device
current_device
=
get_current_device
()
init_ctx
=
LazyInitContext
(
default_device
=
current_device
)
if
isinstance
(
plugin
,
(
GeminiPlugin
,))
else
nullcontext
()
with
init_ctx
:
with
init_ctx
:
model
=
LlamaForCausalLM
(
LlamaConfig
.
from_pretrained
(
args
.
pretrained
)
)
model
=
LlamaForCausalLM
.
from_pretrained
(
args
.
pretrained
)
# Freeze part of parameters.
# Freeze part of parameters.
if
args
.
freeze_non_embeds_params
:
if
args
.
freeze_non_embeds_params
:
freeze_non_embeds_parameters
(
model
=
model
)
freeze_non_embeds_parameters
(
model
=
model
)
# this is essential, otherwise the grad checkpoint will not work.
model
.
train
()
if
args
.
use_grad_checkpoint
:
if
args
.
use_grad_checkpoint
:
model
.
gradient_checkpointing_enable
()
model
.
gradient_checkpointing_enable
()
...
@@ -246,12 +259,14 @@ def main() -> None:
...
@@ -246,12 +259,14 @@ def main() -> None:
adamw_mode
=
True
,
adamw_mode
=
True
,
)
)
if
args
.
warmup_steps
is
None
:
args
.
warmup_steps
=
int
(
args
.
num_epochs
*
0.025
*
(
len
(
dataloader
)
//
args
.
accumulation_steps
))
coordinator
.
print_on_master
(
f
"Warmup steps is set to
{
args
.
warmup_steps
}
"
)
lr_scheduler
=
CosineAnnealingWarmupLR
(
lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer
,
optimizer
=
optimizer
,
total_steps
=
args
.
num_epochs
*
len
(
dataloader
),
total_steps
=
args
.
num_epochs
*
(
len
(
dataloader
)
//
args
.
accumulation_steps
),
warmup_steps
=
args
.
warmup_steps
warmup_steps
=
args
.
warmup_steps
,
if
args
.
warmup_steps
is
not
None
else
int
(
args
.
num_epochs
*
len
(
dataloader
)
*
0.025
),
eta_min
=
0.1
*
args
.
lr
,
eta_min
=
0.1
*
args
.
lr
,
)
)
...
@@ -267,11 +282,9 @@ def main() -> None:
...
@@ -267,11 +282,9 @@ def main() -> None:
torch
.
set_default_dtype
(
torch
.
float
)
torch
.
set_default_dtype
(
torch
.
float
)
if
args
.
load_checkpoint
is
None
:
coordinator
.
print_on_master
(
coordinator
.
print_on_master
(
f
"Load pretrained model checkpoint from
{
args
.
pretrained
}
"
)
f
"Booster init max device memory:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
booster
.
load_model
(
model
,
args
.
pretrained
,
strict
=
False
)
)
coordinator
.
print_on_master
(
f
"Booster init max CUDA memory:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
coordinator
.
print_on_master
(
f
"Booster init max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
f
"Booster init max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
)
)
...
@@ -298,85 +311,109 @@ def main() -> None:
...
@@ -298,85 +311,109 @@ def main() -> None:
coordinator
.
print_on_master
(
f
"Loaded sample at index
{
sampler_start_idx
}
"
)
coordinator
.
print_on_master
(
f
"Loaded sample at index
{
sampler_start_idx
}
"
)
coordinator
.
print_on_master
(
coordinator
.
print_on_master
(
f
"Checkpoint loaded max
CUDA
memory:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
f
"Checkpoint loaded max
device
memory:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
)
coordinator
.
print_on_master
(
coordinator
.
print_on_master
(
f
"Checkpoint loaded
CUDA
memory:
{
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
f
"Checkpoint loaded
device
memory:
{
accelerator
.
memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
)
coordinator
.
print_on_master
(
coordinator
.
print_on_master
(
f
"Checkpoint loaded max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
f
"Checkpoint loaded max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
)
)
num_steps_per_epoch
=
len
(
dataloader
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Activate NEFTune."
)
model
,
handle
=
activate_neftune
(
model
)
num_steps_per_epoch
=
len
(
dataloader
)
//
args
.
accumulation_steps
# If resume training, set the sampler start index to the correct value
# If resume training, set the sampler start index to the correct value
assert
isinstance
(
dataloader
.
sampler
,
StatefulDistributedSampler
)
assert
isinstance
(
dataloader
.
sampler
,
StatefulDistributedSampler
)
dataloader
.
sampler
.
set_start_index
(
start_index
=
sampler_start_idx
)
dataloader
.
sampler
.
set_start_index
(
start_index
=
sampler_start_idx
)
for
epoch
in
range
(
start_epoch
,
args
.
num_epochs
):
for
epoch
in
range
(
start_epoch
,
args
.
num_epochs
):
dataloader
.
sampler
.
set_epoch
(
epoch
=
epoch
)
dataloader
.
sampler
.
set_epoch
(
epoch
=
epoch
)
with
tqdm
(
pbar
=
tqdm
(
iterable
=
enumerate
(
dataloader
,
start
=
start_step
),
desc
=
f
"Epoch
{
epoch
}
"
,
desc
=
f
"Epoch
{
epoch
}
"
,
disable
=
not
coordinator
.
is_master
(),
disable
=
not
coordinator
.
is_master
(),
total
=
num_steps_per_epoch
,
total
=
num_steps_per_epoch
,
initial
=
start_step
,
initial
=
start_step
//
args
.
accumulation_steps
,
)
as
pbar
:
)
for
step
,
batch
in
pbar
:
total_loss
=
torch
.
tensor
(
0.0
,
device
=
get_current_device
())
batch
=
{
k
:
v
.
to
(
current_device
)
for
k
,
v
in
batch
.
items
()
if
isinstance
(
v
,
torch
.
Tensor
)}
for
step
,
batch
in
enumerate
(
dataloader
,
start
=
start_step
):
batch
=
{
k
:
v
.
to
(
get_current_device
())
for
k
,
v
in
batch
.
items
()
if
isinstance
(
v
,
torch
.
Tensor
)}
batch_output
=
model
(
**
batch
)
batch_output
=
model
(
**
batch
)
loss
=
batch_output
.
loss
loss
=
batch_output
.
loss
/
args
.
accumulation_steps
total_loss
.
add_
(
loss
.
data
)
booster
.
backward
(
loss
=
loss
,
optimizer
=
optimizer
)
booster
.
backward
(
loss
=
loss
,
optimizer
=
optimizer
)
if
(
step
+
1
)
%
args
.
accumulation_steps
==
0
:
optimizer
.
step
()
optimizer
.
step
()
lr_scheduler
.
step
()
lr_scheduler
.
step
()
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
all_reduce_mean
(
tensor
=
loss
)
all_reduce_mean
(
tensor
=
total_
loss
)
pbar
.
set_postfix
({
"Loss"
:
f
"
{
loss
.
item
():.
4
f
}
"
})
pbar
.
set_postfix
({
"Loss"
:
f
"
{
total_
loss
.
item
():.
4
f
}
"
})
if
coordinator
.
is_master
():
if
coordinator
.
is_master
():
global_step
=
epoch
*
num_steps_per_epoch
+
step
global_step
=
(
epoch
*
num_steps_per_epoch
)
+
(
step
+
1
)
//
args
.
accumulation_steps
writer
.
add_scalar
(
tag
=
"Loss"
,
scalar_value
=
loss
.
item
(),
global_step
=
global_step
)
writer
.
add_scalar
(
tag
=
"Loss"
,
scalar_value
=
total_
loss
.
item
(),
global_step
=
global_step
)
writer
.
add_scalar
(
writer
.
add_scalar
(
tag
=
"Learning Rate"
,
tag
=
"Learning Rate"
,
scalar_value
=
lr_scheduler
.
get_last_lr
()[
0
],
scalar_value
=
lr_scheduler
.
get_last_lr
()[
0
],
global_step
=
global_step
,
global_step
=
global_step
,
)
)
# Save modeling.
total_loss
.
fill_
(
0.0
)
pbar
.
update
()
if
(
args
.
save_interval
>
0
and
(
step
+
1
)
%
args
.
save_interval
==
0
)
or
(
step
+
1
)
==
len
(
dataloader
):
# Save modeling.
coordinator
.
print_on_master
(
"
\n
Start saving model checkpoint with running states"
)
save_checkpoint
(
if
(
args
.
save_interval
>
0
and
(
step
+
1
)
%
(
args
.
save_interval
*
args
.
accumulation_steps
)
==
0
)
or
(
save_dir
=
args
.
save_dir
,
step
+
1
booster
=
booster
,
)
==
len
(
dataloader
):
model
=
model
,
coordinator
.
print_on_master
(
"
\n
Start saving model checkpoint with running states"
)
optimizer
=
optimizer
,
lr_scheduler
=
lr_scheduler
,
if
args
.
use_neft
:
epoch
=
epoch
,
coordinator
.
print_on_master
(
"Deactivate NEFTune before saving model."
)
step
=
step
+
1
,
deactivate_neftune
(
model
,
handle
)
batch_size
=
args
.
micro_batch_size
,
coordinator
=
coordinator
,
accelerator
.
empty_cache
()
)
save_checkpoint
(
coordinator
.
print_on_master
(
save_dir
=
args
.
save_dir
,
f
"Saved checkpoint at epoch
{
epoch
}
step
{
step
+
1
}
at folder
{
args
.
save_dir
}
"
booster
=
booster
,
)
model
=
model
,
optimizer
=
optimizer
,
# Delete CUDA cache.
lr_scheduler
=
lr_scheduler
,
# del batch, batch_labels, batch_output, loss
epoch
=
epoch
,
torch
.
cuda
.
empty_cache
()
step
=
step
+
1
,
batch_size
=
args
.
micro_batch_size
,
coordinator
=
coordinator
,
)
coordinator
.
print_on_master
(
f
"Saved checkpoint at epoch
{
epoch
}
step
{
step
+
1
}
at folder
{
args
.
save_dir
}
"
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Activate NEFTune."
)
model
,
handle
=
activate_neftune
(
model
)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator
.
empty_cache
()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader
.
sampler
.
set_start_index
(
start_index
=
0
)
dataloader
.
sampler
.
set_start_index
(
start_index
=
0
)
start_step
=
0
start_step
=
0
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Deactivate NEFTune."
)
deactivate_neftune
(
model
,
handle
)
# Final save.
# Final save.
coordinator
.
print_on_master
(
"Start saving final model checkpoint"
)
coordinator
.
print_on_master
(
"Start saving final model checkpoint"
)
booster
.
save_model
(
model
,
os
.
path
.
join
(
args
.
save_dir
,
"modeling"
),
shard
=
True
)
booster
.
save_model
(
model
,
os
.
path
.
join
(
args
.
save_dir
,
"modeling"
),
shard
=
True
)
coordinator
.
print_on_master
(
f
"Saved final model checkpoint at epoch
{
epoch
}
at folder
{
args
.
save_dir
}
"
)
coordinator
.
print_on_master
(
f
"Saved final model checkpoint at epoch
{
epoch
}
at folder
{
args
.
save_dir
}
"
)
coordinator
.
print_on_master
(
f
"Max
CUDA
memory usage:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Max
device
memory usage:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
applications/Colossal-LLaMA-2/train_sft.example.sh
View file @
4c03347f
...
@@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
...
@@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR
=
"
${
PARENT_TENSORBOARD_DIR
}${
FULL_PROJECT_NAME
}
"
TENSORBOARD_DIR
=
"
${
PARENT_TENSORBOARD_DIR
}${
FULL_PROJECT_NAME
}
"
CONFIG_FILE
=
"
${
PARENT_CONFIG_FILE
}${
FULL_PROJECT_NAME
}
.json"
CONFIG_FILE
=
"
${
PARENT_CONFIG_FILE
}${
FULL_PROJECT_NAME
}
.json"
colossalai run
--nproc_per_node
8
--hostfile
hostfile
--master_port
30013 train
_sft
.py
\
colossalai run
--nproc_per_node
8
--hostfile
hostfile
--master_port
30013 train.py
\
--pretrained
$PRETRAINED_MODEL_PATH
\
--pretrained
$PRETRAINED_MODEL_PATH
\
--dataset
${
dataset
[@]
}
\
--dataset
${
dataset
[@]
}
\
--plugin
"zero2"
\
--plugin
"zero2"
\
...
@@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
...
@@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
--use_grad_checkpoint
\
--use_grad_checkpoint
\
--use_flash_attn
\
--use_flash_attn
\
--use_neft
\
--use_neft
\
--pad_token
"eos"
applications/Colossal-LLaMA-2/train_sft.py
deleted
100644 → 0
View file @
c53ddda8
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
"""
import
argparse
import
json
import
os
import
resource
from
contextlib
import
nullcontext
import
torch
import
torch.distributed
as
dist
from
colossal_llama2.dataset.loader
import
(
DataCollatorForSupervisedDataset
,
StatefulDistributedSampler
,
load_tokenized_dataset
,
setup_distributed_dataloader
,
)
from
colossal_llama2.utils.ckpt_io
import
load_checkpoint
,
save_checkpoint
from
colossal_llama2.utils.flash_attention_patch
import
replace_with_flash_attention
from
colossal_llama2.utils.froze
import
freeze_non_embeds_parameters
from
colossal_llama2.utils.neftune_patch
import
activate_neftune
,
deactivate_neftune
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
from
transformers
import
LlamaConfig
,
LlamaForCausalLM
,
LlamaTokenizer
import
colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
GeminiPlugin
,
HybridParallelPlugin
,
LowLevelZeroPlugin
from
colossalai.cluster
import
DistCoordinator
from
colossalai.lazy
import
LazyInitContext
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.utils
import
get_current_device
def
get_model_numel
(
model
:
torch
.
nn
.
Module
)
->
int
:
return
sum
(
p
.
numel
()
for
p
in
model
.
parameters
())
def
format_numel_str
(
numel
:
int
)
->
str
:
B
=
1024
**
3
M
=
1024
**
2
K
=
1024
if
numel
>=
B
:
return
f
"
{
numel
/
B
:.
2
f
}
B"
elif
numel
>=
M
:
return
f
"
{
numel
/
M
:.
2
f
}
M"
elif
numel
>=
K
:
return
f
"
{
numel
/
K
:.
2
f
}
K"
else
:
return
f
"
{
numel
}
"
def
all_reduce_mean
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
dist
.
all_reduce
(
tensor
=
tensor
,
op
=
dist
.
ReduceOp
.
SUM
)
tensor
.
div_
(
dist
.
get_world_size
())
return
tensor
def
main
()
->
None
:
# ==============================
# Parse Arguments
# ==============================
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--pretrained"
,
type
=
str
,
default
=
None
,
help
=
"Address of the pre-trained modeling"
,
)
parser
.
add_argument
(
"--dataset"
,
nargs
=
"+"
,
default
=
[])
parser
.
add_argument
(
"--plugin"
,
type
=
str
,
default
=
"gemini"
,
choices
=
[
"gemini"
,
"gemini_auto"
,
"zero2"
,
"zero2_cpu"
,
"3d"
],
help
=
"Choose which plugin to use"
,
)
parser
.
add_argument
(
"--load_checkpoint"
,
type
=
str
,
default
=
None
,
help
=
"Load checkpoint"
)
parser
.
add_argument
(
"--save_interval"
,
type
=
int
,
default
=
1000
,
help
=
"Save interval"
)
parser
.
add_argument
(
"--save_dir"
,
type
=
str
,
default
=
"checkpoint_dir"
,
help
=
"Checkpoint directory"
)
parser
.
add_argument
(
"--tensorboard_dir"
,
type
=
str
,
default
=
"logs_dir"
,
help
=
"Tensorboard directory"
)
parser
.
add_argument
(
"--config_file"
,
type
=
str
,
default
=
"config_file"
,
help
=
"Config file"
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
1
,
help
=
"Number of training epochs"
)
parser
.
add_argument
(
"--accumulation_steps"
,
type
=
int
,
default
=
8
,
help
=
"Number of accumulation steps"
)
parser
.
add_argument
(
"--micro_batch_size"
,
type
=
int
,
default
=
2
,
help
=
"Batch size of each process"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
3e-4
,
help
=
"Learning rate"
)
parser
.
add_argument
(
"--max_length"
,
type
=
int
,
default
=
4096
,
help
=
"Model max length"
)
parser
.
add_argument
(
"--mixed_precision"
,
type
=
str
,
default
=
"fp16"
,
choices
=
[
"fp16"
,
"bf16"
],
help
=
"Mixed precision"
,
)
parser
.
add_argument
(
"--grad_clip"
,
type
=
float
,
default
=
1.0
,
help
=
"Gradient clipping value"
)
parser
.
add_argument
(
"--weight_decay"
,
type
=
float
,
default
=
0.1
,
help
=
"Weight decay"
)
parser
.
add_argument
(
"--warmup_steps"
,
type
=
int
,
default
=
None
,
help
=
"Warmup steps"
)
parser
.
add_argument
(
"--use_grad_checkpoint"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use gradient checkpointing"
,
)
parser
.
add_argument
(
"--use_flash_attn"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use flash-attention"
,
)
parser
.
add_argument
(
"--use_neft"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use NEFTune"
,
)
parser
.
add_argument
(
"--freeze_non_embeds_params"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Freeze non embeddings parameters"
,
)
parser
.
add_argument
(
"--tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--zero"
,
type
=
int
,
default
=
1
)
args
=
parser
.
parse_args
()
with
open
(
args
.
config_file
,
"w"
)
as
f
:
json
.
dump
(
args
.
__dict__
,
f
,
indent
=
4
)
# ==============================
# Initialize Distributed Training
# ==============================
colossalai
.
launch_from_torch
({})
coordinator
=
DistCoordinator
()
# ==============================
# Initialize Tensorboard
# ==============================
if
coordinator
.
is_master
():
os
.
makedirs
(
args
.
tensorboard_dir
,
exist_ok
=
True
)
writer
=
SummaryWriter
(
args
.
tensorboard_dir
)
# ==============================
# Initialize Booster
# ==============================
if
args
.
plugin
==
"gemini"
:
plugin
=
GeminiPlugin
(
precision
=
args
.
mixed_precision
,
initial_scale
=
2
**
16
,
max_norm
=
args
.
grad_clip
,
)
elif
args
.
plugin
==
"gemini_auto"
:
plugin
=
GeminiPlugin
(
precision
=
args
.
mixed_precision
,
placement_policy
=
"auto"
,
initial_scale
=
2
**
16
,
max_norm
=
args
.
grad_clip
,
)
elif
args
.
plugin
==
"zero2"
:
plugin
=
LowLevelZeroPlugin
(
stage
=
2
,
precision
=
args
.
mixed_precision
,
initial_scale
=
2
**
16
,
max_norm
=
args
.
grad_clip
,
)
elif
args
.
plugin
==
"zero2_cpu"
:
plugin
=
LowLevelZeroPlugin
(
stage
=
2
,
precision
=
args
.
mixed_precision
,
initial_scale
=
2
**
16
,
cpu_offload
=
True
,
max_norm
=
args
.
grad_clip
,
)
elif
args
.
plugin
==
"3d"
:
plugin
=
HybridParallelPlugin
(
tp_size
=
args
.
tp
,
pp_size
=
1
,
zero_stage
=
args
.
zero
,
max_norm
=
args
.
grad_clip
,
precision
=
args
.
mixed_precision
,
)
else
:
raise
ValueError
(
f
"Unknown plugin
{
args
.
plugin
}
"
)
booster
=
Booster
(
plugin
=
plugin
)
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
args
.
pretrained
)
tokenizer
.
pad_token
=
tokenizer
.
eos_token
tokenizer
.
add_bos_token
=
False
tokenizer
.
add_eos_token
=
False
coordinator
.
print_on_master
(
f
"Configuration file will be saved at:
{
args
.
config_file
}
"
)
coordinator
.
print_on_master
(
f
"Tensorboard logs will be saved at:
{
args
.
tensorboard_dir
}
"
)
coordinator
.
print_on_master
(
f
"Model checkpoint will be saved at:
{
args
.
save_dir
}
"
)
coordinator
.
print_on_master
(
f
"Load dataset:
{
args
.
dataset
}
"
)
dataset
=
load_tokenized_dataset
(
dataset_paths
=
args
.
dataset
,
mode
=
"train"
)
data_collator
=
DataCollatorForSupervisedDataset
(
tokenizer
=
tokenizer
,
max_length
=
args
.
max_length
)
dataloader
=
setup_distributed_dataloader
(
dataset
=
dataset
,
batch_size
=
args
.
micro_batch_size
,
shuffle
=
True
,
drop_last
=
True
,
collate_fn
=
data_collator
,
)
coordinator
.
print_on_master
(
f
"Max CUDA memory after data loader:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
init_ctx
=
(
LazyInitContext
(
default_device
=
get_current_device
())
if
isinstance
(
plugin
,
(
GeminiPlugin
,))
else
nullcontext
()
)
with
init_ctx
:
model
=
LlamaForCausalLM
(
LlamaConfig
.
from_pretrained
(
args
.
pretrained
))
# Freeze part of parameters.
if
args
.
freeze_non_embeds_params
:
freeze_non_embeds_parameters
(
model
=
model
)
if
args
.
use_grad_checkpoint
:
model
.
gradient_checkpointing_enable
()
coordinator
.
print_on_master
(
msg
=
"Gradient checkpointing enabled successfully"
)
if
args
.
use_flash_attn
:
replace_with_flash_attention
(
model
=
model
)
coordinator
.
print_on_master
(
msg
=
"Flash-attention enabled successfully"
)
model_numel
=
get_model_numel
(
model
)
coordinator
.
print_on_master
(
f
"Model params:
{
format_numel_str
(
model_numel
)
}
"
)
optimizer
=
HybridAdam
(
model_params
=
filter
(
lambda
p
:
p
.
requires_grad
,
model
.
parameters
())
if
args
.
freeze_non_embeds_params
else
model
.
parameters
(),
lr
=
args
.
lr
,
betas
=
(
0.9
,
0.95
),
weight_decay
=
args
.
weight_decay
,
adamw_mode
=
True
,
)
if
args
.
warmup_steps
is
None
:
args
.
warmup_steps
=
int
(
args
.
num_epochs
*
0.025
*
(
len
(
dataloader
)
//
args
.
accumulation_steps
))
coordinator
.
print_on_master
(
f
"Warmup steps is set to
{
args
.
warmup_steps
}
"
)
lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer
,
total_steps
=
args
.
num_epochs
*
(
len
(
dataloader
)
//
args
.
accumulation_steps
),
warmup_steps
=
args
.
warmup_steps
,
eta_min
=
0.1
*
args
.
lr
,
)
# Flash attention will be disabled because it does NOT support fp32.
default_dtype
=
torch
.
float16
if
args
.
mixed_precision
==
"fp16"
else
torch
.
bfloat16
torch
.
set_default_dtype
(
default_dtype
)
model
,
optimizer
,
_
,
dataloader
,
lr_scheduler
=
booster
.
boost
(
model
=
model
,
optimizer
=
optimizer
,
lr_scheduler
=
lr_scheduler
,
dataloader
=
dataloader
,
)
torch
.
set_default_dtype
(
torch
.
float
)
if
args
.
load_checkpoint
is
None
:
coordinator
.
print_on_master
(
f
"Load pretrained model checkpoint from
{
args
.
pretrained
}
"
)
booster
.
load_model
(
model
,
args
.
pretrained
,
strict
=
False
)
coordinator
.
print_on_master
(
f
"Booster init max CUDA memory:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Booster init max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
)
start_epoch
=
0
start_step
=
0
sampler_start_idx
=
0
if
args
.
load_checkpoint
is
not
None
:
if
"modeling"
in
args
.
load_checkpoint
:
coordinator
.
print_on_master
(
f
"Continued pretrain from checkpoint
{
args
.
load_checkpoint
}
"
)
booster
.
load_model
(
model
,
args
.
load_checkpoint
)
else
:
coordinator
.
print_on_master
(
f
"Load model checkpoint from
{
args
.
load_checkpoint
}
"
)
start_epoch
,
start_step
,
sampler_start_idx
=
load_checkpoint
(
load_dir
=
args
.
load_checkpoint
,
booster
=
booster
,
model
=
model
,
optimizer
=
optimizer
,
lr_scheduler
=
lr_scheduler
,
)
coordinator
.
print_on_master
(
f
"Loaded checkpoint
{
args
.
load_checkpoint
}
at epoch
{
start_epoch
}
step
{
start_step
}
"
)
coordinator
.
print_on_master
(
f
"Loaded sample at index
{
sampler_start_idx
}
"
)
coordinator
.
print_on_master
(
f
"Checkpoint loaded max CUDA memory:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Checkpoint loaded CUDA memory:
{
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Checkpoint loaded max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Activate NEFTune."
)
model
,
handle
=
activate_neftune
(
model
)
num_steps_per_epoch
=
len
(
dataloader
)
//
args
.
accumulation_steps
# If resume training, set the sampler start index to the correct value
assert
isinstance
(
dataloader
.
sampler
,
StatefulDistributedSampler
)
dataloader
.
sampler
.
set_start_index
(
start_index
=
sampler_start_idx
)
for
epoch
in
range
(
start_epoch
,
args
.
num_epochs
):
dataloader
.
sampler
.
set_epoch
(
epoch
=
epoch
)
pbar
=
tqdm
(
desc
=
f
"Epoch
{
epoch
}
"
,
disable
=
not
coordinator
.
is_master
(),
total
=
num_steps_per_epoch
)
total_loss
=
torch
.
tensor
(
0.0
).
to
(
torch
.
cuda
.
current_device
())
for
step
,
batch
in
enumerate
(
dataloader
):
batch
=
{
k
:
v
.
to
(
get_current_device
())
for
k
,
v
in
batch
.
items
()
if
isinstance
(
v
,
torch
.
Tensor
)}
batch_output
=
model
(
**
batch
)
loss
=
batch_output
.
loss
/
args
.
accumulation_steps
total_loss
+=
loss
.
item
()
booster
.
backward
(
loss
=
loss
,
optimizer
=
optimizer
)
if
(
step
+
1
)
%
args
.
accumulation_steps
==
0
:
optimizer
.
step
()
lr_scheduler
.
step
()
optimizer
.
zero_grad
()
all_reduce_mean
(
tensor
=
total_loss
)
pbar
.
set_postfix
({
"Loss"
:
f
"
{
total_loss
.
item
():.
4
f
}
"
})
if
coordinator
.
is_master
():
global_step
=
(
epoch
*
num_steps_per_epoch
)
+
(
step
+
1
)
//
args
.
accumulation_steps
writer
.
add_scalar
(
tag
=
"Loss"
,
scalar_value
=
total_loss
.
item
(),
global_step
=
global_step
)
writer
.
add_scalar
(
tag
=
"Learning Rate"
,
scalar_value
=
lr_scheduler
.
get_last_lr
()[
0
],
global_step
=
global_step
,
)
total_loss
.
fill_
(
0.0
)
pbar
.
update
()
# Save modeling.
if
(
args
.
save_interval
>
0
and
(
step
+
1
)
%
(
args
.
save_interval
*
args
.
accumulation_steps
)
==
0
)
or
(
step
+
1
)
==
len
(
dataloader
):
coordinator
.
print_on_master
(
"
\n
Start saving model checkpoint with running states"
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Deactivate NEFTune before saving model."
)
deactivate_neftune
(
model
,
handle
)
save_checkpoint
(
save_dir
=
args
.
save_dir
,
booster
=
booster
,
model
=
model
,
optimizer
=
optimizer
,
lr_scheduler
=
lr_scheduler
,
epoch
=
epoch
,
step
=
step
+
1
,
batch_size
=
args
.
micro_batch_size
,
coordinator
=
coordinator
,
)
coordinator
.
print_on_master
(
f
"Saved checkpoint at epoch
{
epoch
}
step
{
step
+
1
}
at folder
{
args
.
save_dir
}
"
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Activate NEFTune."
)
model
,
handle
=
activate_neftune
(
model
)
# Delete CUDA cache.
# del batch, batch_labels, batch_output, loss
torch
.
cuda
.
empty_cache
()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader
.
sampler
.
set_start_index
(
start_index
=
0
)
start_step
=
0
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Deactivate NEFTune."
)
deactivate_neftune
(
model
,
handle
)
# Final save.
coordinator
.
print_on_master
(
"Start saving final model checkpoint"
)
booster
.
save_model
(
model
,
os
.
path
.
join
(
args
.
save_dir
,
"modeling"
),
shard
=
True
)
coordinator
.
print_on_master
(
f
"Saved final model checkpoint at epoch
{
epoch
}
at folder
{
args
.
save_dir
}
"
)
coordinator
.
print_on_master
(
f
"Max CUDA memory usage:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
# Script entry point: run training only when this file is executed directly
# (not when imported as a module).
if __name__ == "__main__":
    main()
applications/ColossalEval/colossal_eval/models/chatglm.py
View file @
4c03347f
...
@@ -3,6 +3,8 @@ from typing import List
...
@@ -3,6 +3,8 @@ from typing import List
import
torch
import
torch
from
colossalai.utils
import
get_current_device
from
.huggingface
import
HuggingFaceModel
from
.huggingface
import
HuggingFaceModel
IGNORE_INDEX
=
-
100
IGNORE_INDEX
=
-
100
...
@@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel):
...
@@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel):
"""
"""
input_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
input_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
input_ids_list
,
batch_first
=
True
,
padding_value
=
self
.
tokenizer
.
pad_token_id
input_ids_list
,
batch_first
=
True
,
padding_value
=
self
.
tokenizer
.
pad_token_id
).
to
(
torch
.
cuda
.
current_device
())
).
to
(
get_
current_device
())
labels
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
labels
,
batch_first
=
True
,
padding_value
=
IGNORE_INDEX
).
to
(
labels
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
labels
,
batch_first
=
True
,
padding_value
=
IGNORE_INDEX
).
to
(
torch
.
cuda
.
current_device
()
get_
current_device
()
)
)
outputs
=
self
.
model
(
input_ids
)[
0
]
outputs
=
self
.
model
(
input_ids
)[
0
]
...
@@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel):
...
@@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel):
truncation
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
,
return_tensors
=
"pt"
,
max_length
=
self
.
model_max_length
-
max_new_tokens
,
max_length
=
self
.
model_max_length
-
max_new_tokens
,
).
to
(
torch
.
cuda
.
current_device
())
).
to
(
get_
current_device
())
# Set output_scores=True to get prediction scores.
# Set output_scores=True to get prediction scores.
outputs
=
self
.
model
.
generate
(
outputs
=
self
.
model
.
generate
(
...
...
applications/ColossalEval/colossal_eval/models/huggingface.py
View file @
4c03347f
...
@@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni
...
@@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni
from
colossalai.logging
import
DistributedLogger
from
colossalai.logging
import
DistributedLogger
from
colossalai.shardformer
import
ShardConfig
,
ShardFormer
from
colossalai.shardformer
import
ShardConfig
,
ShardFormer
from
colossalai.utils
import
get_current_device
from
.base
import
BaseModel
from
.base
import
BaseModel
...
@@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel):
...
@@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel):
self
.
model
=
AutoModel
.
from_pretrained
(
path
,
**
model_kwargs
)
self
.
model
=
AutoModel
.
from_pretrained
(
path
,
**
model_kwargs
)
shard_former
=
ShardFormer
(
shard_config
)
shard_former
=
ShardFormer
(
shard_config
)
self
.
model
,
sharded_parameters
=
shard_former
.
optimize
(
self
.
model
)
self
.
model
,
sharded_parameters
=
shard_former
.
optimize
(
self
.
model
)
self
.
model
.
to
(
torch
.
cuda
.
current_device
())
self
.
model
.
to
(
get_
current_device
())
if
peft_path
is
not
None
:
if
peft_path
is
not
None
:
raise
NotImplementedError
(
"ShardFormer for PEFT models is not implemented."
)
raise
NotImplementedError
(
"ShardFormer for PEFT models is not implemented."
)
else
:
else
:
self
.
model
=
AutoModel
.
from_pretrained
(
path
,
**
model_kwargs
).
to
(
torch
.
cuda
.
current_device
())
self
.
model
=
AutoModel
.
from_pretrained
(
path
,
**
model_kwargs
).
to
(
get_
current_device
())
if
peft_path
is
not
None
:
if
peft_path
is
not
None
:
self
.
model
=
PeftModel
.
from_pretrained
(
self
.
model
,
peft_path
,
is_trainable
=
False
)
self
.
model
=
PeftModel
.
from_pretrained
(
self
.
model
,
peft_path
,
is_trainable
=
False
)
self
.
model
.
eval
()
self
.
model
.
eval
()
...
@@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel):
...
@@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel):
"""
"""
input_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
input_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
input_ids_list
,
batch_first
=
True
,
padding_value
=
self
.
tokenizer
.
pad_token_id
input_ids_list
,
batch_first
=
True
,
padding_value
=
self
.
tokenizer
.
pad_token_id
).
to
(
torch
.
cuda
.
current_device
())
).
to
(
get_
current_device
())
labels
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
labels
,
batch_first
=
True
,
padding_value
=
IGNORE_INDEX
).
to
(
labels
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
labels
,
batch_first
=
True
,
padding_value
=
IGNORE_INDEX
).
to
(
torch
.
cuda
.
current_device
()
get_
current_device
()
)
)
attention_mask
=
input_ids
.
ne
(
self
.
tokenizer
.
pad_token_id
).
to
(
torch
.
cuda
.
current_device
())
attention_mask
=
input_ids
.
ne
(
self
.
tokenizer
.
pad_token_id
).
to
(
get_
current_device
())
outputs
=
self
.
model
(
input_ids
,
attention_mask
=
attention_mask
)[
0
]
outputs
=
self
.
model
(
input_ids
,
attention_mask
=
attention_mask
)[
0
]
...
@@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel):
...
@@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel):
return_tensors
=
"pt"
,
return_tensors
=
"pt"
,
return_token_type_ids
=
False
,
return_token_type_ids
=
False
,
max_length
=
self
.
model_max_length
-
max_new_tokens
,
max_length
=
self
.
model_max_length
-
max_new_tokens
,
).
to
(
torch
.
cuda
.
current_device
())
).
to
(
get_
current_device
())
# Set output_scores=True to get prediction scores.
# Set output_scores=True to get prediction scores.
outputs
=
self
.
model
.
generate
(
outputs
=
self
.
model
.
generate
(
...
@@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel):
...
@@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel):
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
path
,
**
model_kwargs
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
path
,
**
model_kwargs
)
shard_former
=
ShardFormer
(
shard_config
)
shard_former
=
ShardFormer
(
shard_config
)
self
.
model
,
sharded_parameters
=
shard_former
.
optimize
(
self
.
model
)
self
.
model
,
sharded_parameters
=
shard_former
.
optimize
(
self
.
model
)
self
.
model
.
to
(
torch
.
cuda
.
current_device
())
self
.
model
.
to
(
get_
current_device
())
if
peft_path
is
not
None
:
if
peft_path
is
not
None
:
raise
NotImplementedError
(
"ShardFormer for PEFT models is not implemented."
)
raise
NotImplementedError
(
"ShardFormer for PEFT models is not implemented."
)
else
:
else
:
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
path
,
**
model_kwargs
).
to
(
torch
.
cuda
.
current_device
())
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
path
,
**
model_kwargs
).
to
(
get_
current_device
())
if
peft_path
is
not
None
:
if
peft_path
is
not
None
:
self
.
model
=
PeftModel
.
from_pretrained
(
self
.
model
,
peft_path
,
is_trainable
=
False
)
self
.
model
=
PeftModel
.
from_pretrained
(
self
.
model
,
peft_path
,
is_trainable
=
False
)
...
...
applications/ColossalEval/examples/dataset_evaluation/inference.py
View file @
4c03347f
...
@@ -8,6 +8,7 @@ import torch.distributed as dist
...
@@ -8,6 +8,7 @@ import torch.distributed as dist
from
colossal_eval
import
dataset
,
models
,
utils
from
colossal_eval
import
dataset
,
models
,
utils
import
colossalai
import
colossalai
from
colossalai.accelerator
import
get_accelerator
from
colossalai.cluster
import
ProcessGroupMesh
from
colossalai.cluster
import
ProcessGroupMesh
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.shardformer
import
ShardConfig
from
colossalai.shardformer
import
ShardConfig
...
@@ -82,6 +83,7 @@ def rm_and_merge(
...
@@ -82,6 +83,7 @@ def rm_and_merge(
def
main
(
args
):
def
main
(
args
):
colossalai
.
launch_from_torch
(
config
=
{},
seed
=
42
)
colossalai
.
launch_from_torch
(
config
=
{},
seed
=
42
)
accelerator
=
get_accelerator
()
world_size
=
dist
.
get_world_size
()
world_size
=
dist
.
get_world_size
()
rank
=
dist
.
get_rank
()
rank
=
dist
.
get_rank
()
...
@@ -235,10 +237,10 @@ def main(args):
...
@@ -235,10 +237,10 @@ def main(args):
),
),
)
)
logger
.
info
(
f
"Rank
{
rank
}
peak
CUDA
mem:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
3
:.
3
f
}
GB"
)
logger
.
info
(
f
"Rank
{
rank
}
peak
device
mem:
{
accelerator
.
max_memory_allocated
()
/
1024
**
3
:.
3
f
}
GB"
)
del
model_
del
model_
torch
.
cuda
.
empty_cache
()
accelerator
.
empty_cache
()
dist
.
barrier
()
dist
.
barrier
()
if
rank
==
0
:
if
rank
==
0
:
...
...
colossalai/booster/plugin/dp_plugin_base.py
View file @
4c03347f
...
@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
...
@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
self
.
world_size
=
dist
.
get_world_size
()
self
.
world_size
=
dist
.
get_world_size
()
def
prepare_dataloader
(
def
prepare_dataloader
(
self
,
dataset
,
batch_size
,
shuffle
=
False
,
seed
=
1024
,
drop_last
=
False
,
pin_memory
=
False
,
num_workers
=
0
,
**
kwargs
self
,
dataset
,
batch_size
,
shuffle
=
False
,
seed
=
1024
,
drop_last
=
False
,
pin_memory
=
False
,
num_workers
=
0
,
distributed_sampler_cls
=
None
,
**
kwargs
,
):
):
r
"""
r
"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
Prepare a dataloader for distributed training. The dataloader will be wrapped by
...
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
...
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
"""
_kwargs
=
kwargs
.
copy
()
_kwargs
=
kwargs
.
copy
()
sampler
=
DistributedSampler
(
dataset
,
num_replicas
=
self
.
world_size
,
rank
=
self
.
rank
,
shuffle
=
shuffle
)
distributed_sampler_cls
=
distributed_sampler_cls
or
DistributedSampler
sampler
=
distributed_sampler_cls
(
dataset
,
num_replicas
=
self
.
world_size
,
rank
=
self
.
rank
,
shuffle
=
shuffle
)
# Deterministic dataloader
# Deterministic dataloader
def
seed_worker
(
worker_id
):
def
seed_worker
(
worker_id
):
...
...
colossalai/booster/plugin/gemini_plugin.py
View file @
4c03347f
...
@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
...
@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
return
[
"cuda"
,
"npu"
]
return
[
"cuda"
,
"npu"
]
def
prepare_dataloader
(
def
prepare_dataloader
(
self
,
dataset
,
batch_size
,
shuffle
=
False
,
seed
=
1024
,
drop_last
=
False
,
pin_memory
=
False
,
num_workers
=
0
,
**
kwargs
self
,
dataset
,
batch_size
,
shuffle
=
False
,
seed
=
1024
,
drop_last
=
False
,
pin_memory
=
False
,
num_workers
=
0
,
distributed_sampler_cls
=
None
,
**
kwargs
,
):
):
r
"""
r
"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
Prepare a dataloader for distributed training. The dataloader will be wrapped by
...
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
...
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
extra_dp_world_size
=
self
.
pg_mesh
.
size
(
DP_AXIS
)
extra_dp_world_size
=
self
.
pg_mesh
.
size
(
DP_AXIS
)
zero_rank
=
self
.
pg_mesh
.
coordinate
(
ZERO_AXIS
)
zero_rank
=
self
.
pg_mesh
.
coordinate
(
ZERO_AXIS
)
extra_dp_rank
=
self
.
pg_mesh
.
coordinate
(
DP_AXIS
)
extra_dp_rank
=
self
.
pg_mesh
.
coordinate
(
DP_AXIS
)
sampler
=
DistributedSampler
(
distributed_sampler_cls
=
distributed_sampler_cls
or
DistributedSampler
sampler
=
distributed_sampler_cls
(
dataset
,
dataset
,
num_replicas
=
zero_world_size
*
extra_dp_world_size
,
num_replicas
=
zero_world_size
*
extra_dp_world_size
,
rank
=
zero_rank
*
extra_dp_world_size
+
extra_dp_rank
,
rank
=
zero_rank
*
extra_dp_world_size
+
extra_dp_rank
,
...
...
colossalai/booster/plugin/hybrid_parallel_plugin.py
View file @
4c03347f
...
@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
...
@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
return
outputs
return
outputs
def
prepare_dataloader
(
def
prepare_dataloader
(
self
,
dataset
,
batch_size
,
shuffle
=
False
,
seed
=
1024
,
drop_last
=
False
,
pin_memory
=
False
,
num_workers
=
0
,
**
kwargs
self
,
dataset
,
batch_size
,
shuffle
=
False
,
seed
=
1024
,
drop_last
=
False
,
pin_memory
=
False
,
num_workers
=
0
,
distributed_sampler_cls
=
None
,
**
kwargs
,
):
):
r
"""
r
"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
Prepare a dataloader for distributed training. The dataloader will be wrapped by
...
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
...
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
"""
_kwargs
=
kwargs
.
copy
()
_kwargs
=
kwargs
.
copy
()
sampler
=
DistributedSampler
(
distributed_sampler_cls
=
distributed_sampler_cls
or
DistributedSampler
sampler
=
distributed_sampler_cls
(
dataset
,
num_replicas
=
self
.
pg_mesh
.
size
(
DP_AXIS
),
rank
=
self
.
pg_mesh
.
coordinate
(
DP_AXIS
),
shuffle
=
shuffle
dataset
,
num_replicas
=
self
.
pg_mesh
.
size
(
DP_AXIS
),
rank
=
self
.
pg_mesh
.
coordinate
(
DP_AXIS
),
shuffle
=
shuffle
)
)
...
...
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
View file @
4c03347f
...
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
...
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from
colossalai.cluster
import
DistCoordinator
from
colossalai.cluster
import
DistCoordinator
from
colossalai.interface
import
ModelWrapper
,
OptimizerWrapper
from
colossalai.interface
import
ModelWrapper
,
OptimizerWrapper
from
colossalai.utils
import
get_current_device
from
.general_checkpoint_io
import
GeneralCheckpointIO
from
.general_checkpoint_io
import
GeneralCheckpointIO
from
.index_file
import
CheckpointIndexFile
from
.index_file
import
CheckpointIndexFile
...
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
...
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group
=
self
.
tp_group
,
tp_group
=
self
.
tp_group
,
use_zero
=
self
.
use_zero
,
use_zero
=
self
.
use_zero
,
inplace
=
False
,
inplace
=
False
,
device
=
torch
.
device
(
"cuda"
),
device
=
get_current_
device
(),
)
)
if
self
.
pp_size
==
1
:
if
self
.
pp_size
==
1
:
...
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
...
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if
isinstance
(
v
,
torch
.
Tensor
)
and
k
!=
"step"
:
if
isinstance
(
v
,
torch
.
Tensor
)
and
k
!=
"step"
:
# First gather Zero shards.
# First gather Zero shards.
if
use_zero
:
if
use_zero
:
v
=
v
.
cuda
()
v
=
v
.
to
(
get_current_device
()
)
gather_tensor
=
[
torch
.
zeros_like
(
v
)
for
_
in
range
(
dp_size
)]
gather_tensor
=
[
torch
.
zeros_like
(
v
)
for
_
in
range
(
dp_size
)]
dist
.
all_gather
(
gather_tensor
,
v
,
group
=
dp_group
)
dist
.
all_gather
(
gather_tensor
,
v
,
group
=
dp_group
)
v
=
torch
.
stack
(
gather_tensor
).
view
(
-
1
)[:
param
.
numel
()].
reshape_as
(
param
)
v
=
torch
.
stack
(
gather_tensor
).
view
(
-
1
)[:
param
.
numel
()].
reshape_as
(
param
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment