OpenDAS / Megatron-LM · Commit 5942af97

Authored Sep 23, 2022 by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

Parent: c2ea914f
Showing 20 of 63 changed files, with 148 additions and 152 deletions (+148 −152).
Changed files shown on this page:

megatron/checkpointing.py                    +19 −19
megatron/core/__init__.py                    +12 −13
megatron/core/parallel_state.py               +0 −4
megatron/core/tensor_parallel/__init__.py     +7 −2
megatron/core/tensor_parallel/random.py       +4 −32
megatron/core/tensor_parallel/utils.py       +66 −9
megatron/core/utils.py                        +0 −29
megatron/data/biencoder_dataset_utils.py      +3 −2
megatron/data/blendable_dataset.py            +0 −2
megatron/data/data_samplers.py                +1 −1
megatron/data/dataset_utils.py                +1 −1
megatron/data/gpt_dataset.py                  +2 −1
megatron/data/orqa_wiki_dataset.py            +3 −2
megatron/data/realm_dataset_utils.py          +4 −3
megatron/data/realm_index.py                  +7 −7
megatron/indexer.py                           +1 −1
megatron/initialize.py                       +12 −17
megatron/model/bert_model.py                  +4 −4
megatron/model/biencoder_model.py             +2 −2
megatron/model/classification.py              +0 −1
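The same substitution repeats across the files below: transitional imports of the core package are replaced with imports from megatron.core, and call sites move from the short-lived core.* spelling back to the legacy mpu.* spelling (now an alias) or to tensor_parallel.*. A rough before/after sketch of the recurring pattern (illustrative, not an exact excerpt from any one file):

    # before                                    # after
    from megatron import mpu                    from megatron.core import mpu
    from megatron import core                   from megatron.core import mpu, tensor_parallel
    core.get_data_parallel_rank()               mpu.get_data_parallel_rank()
    mpu.broadcast_data(keys, data, datatype)    tensor_parallel.broadcast_data(keys, data, datatype)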
megatron/checkpointing.py

@@ -9,8 +9,8 @@ import numpy as np
 import torch

-from megatron import (core,
-                      update_num_microbatches)
+from megatron import update_num_microbatches
+from megatron.core import mpu, tensor_parallel
 from .global_vars import get_args
 from .utils import (unwrap_model,
                     print_rank_0)

@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     # Use both the tensor and pipeline MP rank.
     if pipeline_parallel is None:
-        pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
+        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
-        tensor_rank = core.get_tensor_model_parallel_rank()
+        tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
-        pipeline_rank = core.get_pipeline_model_parallel_rank()
+        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the

@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
-            common_path + "_%03d" % core.get_data_parallel_rank(),
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
            "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")

@@ -185,18 +185,18 @@ def get_rng_state():
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
-       'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
+       'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

    rng_state_list = None
    if torch.distributed.is_initialized() and \
-            core.get_data_parallel_world_size() > 1 and \
+            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        rng_state_list = \
-            [None for i in range(core.get_data_parallel_world_size())]
+            [None for i in range(mpu.get_data_parallel_world_size())]
        torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
-            group=core.get_data_parallel_group())
+            group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]

@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    # Collect args, model, RNG.
    model_state_dict = {}
    if not torch.distributed.is_initialized() \
-            or core.get_data_parallel_rank() == 0:
+            or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        model_state_dict['args'] = args

@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
-                core.set_virtual_pipeline_model_parallel_rank(i)
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model_state_dict['model%d' % i] = \
                    model[i].state_dict_for_save_checkpoint()

@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    optim_state_dict = {}
    if not args.no_save_optim \
       and (not torch.distributed.is_initialized()
-            or core.get_data_parallel_rank() == 0
+            or mpu.get_data_parallel_rank() == 0
            or args.use_distributed_optimizer):

        # Optimizer stuff.

@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
        model[0].load_state_dict(model_state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
-            core.set_virtual_pipeline_model_parallel_rank(i)
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed

@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
-                    rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
+                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = model_state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])

@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                # Check for empty states array
                if not rng_state['rng_tracker_states']:
                    raise KeyError
-                core.tensor_parallel.get_cuda_rng_tracker().set_states(
+                tensor_parallel.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # backward compatability
                random.setstate(model_state_dict['random_rng_state'])

@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                # Check for empty states array
                if not model_state_dict['rng_tracker_states']:
                    raise KeyError
-                core.tensor_parallel.get_cuda_rng_tracker().set_states(
+                tensor_parallel.get_cuda_rng_tracker().set_states(
                    model_state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '

@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
                                        args.use_distributed_optimizer,
                                        release=False)

-    if core.get_data_parallel_rank() == 0:
+    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

-    if core.get_data_parallel_rank() == 0:
+    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
megatron/core/__init__.py

-from .parallel_state import (
-    initialize_model_parallel,
-    get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_rank,
-    get_pipeline_model_parallel_world_size,
-    get_pipeline_model_parallel_rank,
-    get_virtual_pipeline_model_parallel_rank,
-    set_virtual_pipeline_model_parallel_rank,
-    get_data_parallel_world_size,
-    get_data_parallel_rank,
-    get_global_memory_buffer,
-    get_num_layers,
-)
-from megatron.core import tensor_parallel
+import megatron.core.parallel_state
+import megatron.core.tensor_parallel
+import megatron.core.utils
+
+# Alias parallel_state as mpu, its legacy name
+mpu = parallel_state
+
+__all__ = [
+    "parallel_state",
+    "tensor_parallel",
+    "utils",
+]
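Because mpu is bound directly to the parallel_state module object, the two spellings refer to the same functions; a minimal sketch of the relationship (illustrative only, and assuming model parallelism has been initialized elsewhere):

    from megatron.core import mpu, parallel_state

    assert mpu is parallel_state
    # Legacy-style call sites therefore keep working after this commit:
    dp_rank = mpu.get_data_parallel_rank()
    tp_size = mpu.get_tensor_model_parallel_world_size()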
megatron/core/parallel_state.py

@@ -47,10 +47,6 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
 # Memory buffers to avoid dynamic memory allocation
 _GLOBAL_MEMORY_BUFFER = None

-def is_unitialized():
-    """Useful for code segments that may be accessed with or without mpu initialization"""
-    return _DATA_PARALLEL_GROUP is None
-
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
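With is_unitialized() deleted, code that wanted a guard usable before or after initialization switches to the existing model_parallel_is_initialized() accessor, inverted; the realm_index.py hunks further down use exactly this form. A small sketch of the replacement guard (illustrative; model_parallel_is_initialized is assumed to be the accessor already defined in parallel_state.py):

    from megatron.core import mpu

    def on_rank_zero_or_uninitialized():
        # Old guard: mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0
        return not mpu.model_parallel_is_initialized() or \
            mpu.get_data_parallel_rank() == 0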
megatron/core/tensor_parallel/__init__.py

@@ -5,6 +5,7 @@ from .layers import (
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
+    set_tensor_model_parallel_attributes,
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,

@@ -23,10 +24,14 @@ from .mappings import (
 from .random import (
     checkpoint,
     get_cuda_rng_tracker,
-    model_parallel_cuda_manual_seed
+    model_parallel_cuda_manual_seed,
 )

-from .utils import split_tensor_along_last_dim
+from .utils import (
+    split_tensor_along_last_dim,
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)

 __all__ = [
     # cross_entropy.py
megatron/core/tensor_parallel/random.py

@@ -17,6 +17,10 @@ from megatron.core.parallel_state import (
     get_tensor_model_parallel_world_size,
 )

+from .utils import (
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)

 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

@@ -55,38 +59,6 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)


-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
-    """Break a tensor into equal 1D chunks."""
-    partition_size = torch.numel(tensor) // \
-        get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    if new_buffer:
-        data = torch.empty(partition_size, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-        data.copy_(tensor.view(-1)[start_index:end_index])
-    else:
-        data = tensor.view(-1)[start_index:end_index]
-    return data
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    numel_gathered = torch.numel(tensor) * \
-        get_tensor_model_parallel_world_size()
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-    # TODO: This API is experimental in pytorch (as of Feb 2022) and
-    # this might break in future pytorch releases. We chose this API
-    # as opposed to torch.distributed.all_gather for efficiency reasons.
-    # This API calls directly NCCL all-gather versus the former does
-    # internal copies and can potentially cause slow down.
-    torch.distributed._all_gather_base(gathered, tensor,
-                                       group=get_tensor_model_parallel_group())
-    return gathered
-
-
 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
megatron/core/tensor_parallel/utils.py

@@ -10,12 +10,16 @@ def split_tensor_along_last_dim(
     num_partitions: int,
     contiguous_split_chunks: bool = False,
 ) -> List[torch.Tensor]:
-    """Split a tensor along its last dimension.
+    """ Split a tensor along its last dimension.

         Arguments:
             tensor: input tensor.
             num_partitions: number of partitions to split the tensor
             contiguous_split_chunks: If True, make each chunk contiguous
                                      in memory.
+
+        Returns:
+            A list of Tensors
     """
     # Get the size and dimension.
     last_dim = tensor.dim() - 1

@@ -28,11 +32,64 @@ def split_tensor_along_last_dim(
     return tensor_list


+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """ Break a tensor into equal 1D chunks across tensor parallel ranks.
+
+        Returns a Tensor or View with this rank's portion of the data.
+
+        Arguments:
+            tensor: The tensor to split
+
+        Keyword Arguments:
+            new_buffer (bool): If True, returns a new Tensor.
+                               If False, returns a view into the existing Tensor.
+                               Default is False
+
+    """
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
+    start_index = partition_size * get_tensor_model_parallel_rank()
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
+        model parallel ranks.
+
+        Returns a new Tensor with the gathered data.
+
+        Arguments:
+            tensor: A Tensor or view of this rank's portion of the data.
+    """
+    numel_gathered = torch.numel(tensor) * \
+        get_tensor_model_parallel_world_size()
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    # TODO: This API is experimental in pytorch (as of Feb 2022) and
+    # this might break in future pytorch releases. We chose this API
+    # as opposed to torch.distributed.all_gather for efficiency reasons.
+    # This API calls directly NCCL all-gather versus the former does
+    # internal copies and can potentially cause slow down.
+    torch.distributed._all_gather_base(gathered, tensor,
+                                       group=get_tensor_model_parallel_group())
+    return gathered
+
+
 class VocabUtility:
-    """Split the vocabulary into `world_size` chunks and return the
-    first and last index of the vocabulary belonging to the `rank`
-    partition: Note that indices in [fist, last)"""
+    """ Split the vocabulary into `world_size` chunks and return the first
+        and last index of the vocabulary belonging to the `rank`
+        partition: Note that indices in [fist, last)
+    """

     @staticmethod
     def vocab_range_from_per_partition_vocab_size(
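The two helpers added above now sit next to split_tensor_along_last_dim and are re-exported from megatron.core.tensor_parallel. A short usage sketch of the round trip (illustrative; assumes torch.distributed and Megatron's model-parallel groups are already initialized):

    import torch
    from megatron.core import tensor_parallel

    def roundtrip(full_tensor):
        # Each tensor-parallel rank keeps only its contiguous 1D slice of the data.
        local_chunk = tensor_parallel.split_tensor_into_1d_equal_chunks(
            full_tensor, new_buffer=True)
        # An all-gather over the tensor-parallel group rebuilds the flattened tensor.
        gathered = tensor_parallel.gather_split_1d_tensor(local_chunk)
        return gathered.view_as(full_tensor)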
megatron/core/utils.py

@@ -21,35 +21,6 @@ def divide(numerator, denominator):
     return numerator // denominator


-def split_tensor_into_1d_equal_chunks(tensor):
-    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = (
-        torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
-    )
-    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    return data[start_index:end_index]
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    world_size = parallel_state.get_tensor_model_parallel_world_size()
-    numel = torch.numel(tensor)
-    numel_gathered = world_size * numel
-    gathered = torch.empty(
-        numel_gathered,
-        dtype=tensor.dtype,
-        device=torch.cuda.current_device(),
-        requires_grad=False,
-    )
-    torch.distributed._all_gather_base(
-        gathered, tensor, group=parallel_state.get_tensor_model_parallel_group()
-    )
-    return gathered
-
-
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
        Caller should ensure that buffers of the same name
megatron/data/biencoder_dataset_utils.py

@@ -4,7 +4,8 @@ import time
 import numpy as np
 import torch

-from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron import get_args, get_tokenizer, print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, \
     pad_and_convert_to_numpy
 from megatron.data.data_samplers import MegatronPretrainingSampler

@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     query_tokens = data_b['query_tokens'].long()
megatron/data/blendable_dataset.py

@@ -8,8 +8,6 @@ import numpy as np
 import torch

 from megatron import print_rank_0
-from megatron import mpu
-

 class BlendableDataset(torch.utils.data.Dataset):
megatron/data/data_samplers.py

@@ -8,7 +8,7 @@ import torch
 import numpy as np
 from torch.utils.data import Dataset
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu


 def build_pretraining_data_loader(dataset, consumed_samples):
megatron/data/dataset_utils.py

@@ -28,9 +28,9 @@ import torch
 from megatron import (
     get_args,
-    mpu,
     print_rank_0
 )
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
megatron/data/gpt_dataset.py

@@ -8,7 +8,8 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
 from megatron.data.dataset_utils import get_train_valid_test_split_
megatron/data/orqa_wiki_dataset.py

@@ -9,7 +9,8 @@ import random
 import torch
 from torch.utils.data import Dataset

-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
+from megatron.core import tensor_parallel
 from megatron.data.biencoder_dataset_utils import make_attention_mask

 def get_open_retrieval_wiki_dataset():

@@ -32,7 +33,7 @@ def get_open_retrieval_batch(data_iterator):
     # Broadcast data.
     data = None if data_iterator is None else next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     row_id = data_b['row_id'].long()
megatron/data/realm_dataset_utils.py

@@ -4,9 +4,10 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, print_rank_0


 def get_one_epoch_dataloader(dataset, micro_batch_size=None):

@@ -47,7 +48,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     query_tokens = data_b['query_tokens'].long()
...
megatron/data/realm_index.py
View file @
5942af97
...
@@ -7,7 +7,7 @@ import numpy as np
...
@@ -7,7 +7,7 @@ import numpy as np
import
torch
import
torch
from
megatron
import
get_args
from
megatron
import
get_args
from
megatron
import
mpu
from
megatron
.core
import
mpu
def
detach
(
tensor
):
def
detach
(
tensor
):
...
@@ -50,10 +50,10 @@ class OpenRetreivalDataStore(object):
...
@@ -50,10 +50,10 @@ class OpenRetreivalDataStore(object):
def
load_from_file
(
self
):
def
load_from_file
(
self
):
"""Populate members from instance saved to file"""
"""Populate members from instance saved to file"""
if
mpu
.
is_
u
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
not
mpu
.
model_parallel_
is_
i
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
"
\n
> Unpickling BlockData"
,
flush
=
True
)
print
(
"
\n
> Unpickling BlockData"
,
flush
=
True
)
state_dict
=
pickle
.
load
(
open
(
self
.
embedding_path
,
'rb'
))
state_dict
=
pickle
.
load
(
open
(
self
.
embedding_path
,
'rb'
))
if
mpu
.
is_
u
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
not
mpu
.
model_parallel_
is_
i
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">> Finished unpickling BlockData
\n
"
,
flush
=
True
)
print
(
">> Finished unpickling BlockData
\n
"
,
flush
=
True
)
self
.
embed_data
=
state_dict
[
'embed_data'
]
self
.
embed_data
=
state_dict
[
'embed_data'
]
...
@@ -137,7 +137,7 @@ class FaissMIPSIndex(object):
...
@@ -137,7 +137,7 @@ class FaissMIPSIndex(object):
except
ImportError
:
except
ImportError
:
raise
Exception
(
"Error: Please install faiss to use FaissMIPSIndex"
)
raise
Exception
(
"Error: Please install faiss to use FaissMIPSIndex"
)
if
mpu
.
is_
u
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
not
mpu
.
model_parallel_
is_
i
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
"
\n
> Building index"
,
flush
=
True
)
print
(
"
\n
> Building index"
,
flush
=
True
)
cpu_index
=
faiss
.
IndexFlatIP
(
self
.
embed_size
)
cpu_index
=
faiss
.
IndexFlatIP
(
self
.
embed_size
)
...
@@ -149,12 +149,12 @@ class FaissMIPSIndex(object):
...
@@ -149,12 +149,12 @@ class FaissMIPSIndex(object):
config
.
useFloat16
=
True
config
.
useFloat16
=
True
gpu_index
=
faiss
.
index_cpu_to_all_gpus
(
cpu_index
,
co
=
config
)
gpu_index
=
faiss
.
index_cpu_to_all_gpus
(
cpu_index
,
co
=
config
)
self
.
mips_index
=
faiss
.
IndexIDMap
(
gpu_index
)
self
.
mips_index
=
faiss
.
IndexIDMap
(
gpu_index
)
if
mpu
.
is_
u
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
not
mpu
.
model_parallel_
is_
i
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">> Initialized index on GPU"
,
flush
=
True
)
print
(
">> Initialized index on GPU"
,
flush
=
True
)
else
:
else
:
# CPU index supports IDs so wrap with IDMap
# CPU index supports IDs so wrap with IDMap
self
.
mips_index
=
faiss
.
IndexIDMap
(
cpu_index
)
self
.
mips_index
=
faiss
.
IndexIDMap
(
cpu_index
)
if
mpu
.
is_
u
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
not
mpu
.
model_parallel_
is_
i
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">> Initialized index on CPU"
,
flush
=
True
)
print
(
">> Initialized index on CPU"
,
flush
=
True
)
# if we were constructed with a BlockData, then automatically load it
# if we were constructed with a BlockData, then automatically load it
...
@@ -199,7 +199,7 @@ class FaissMIPSIndex(object):
...
@@ -199,7 +199,7 @@ class FaissMIPSIndex(object):
self
.
mips_index
.
add_with_ids
(
embeds_arr
,
indices_arr
)
self
.
mips_index
.
add_with_ids
(
embeds_arr
,
indices_arr
)
if
mpu
.
is_
u
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
not
mpu
.
model_parallel_
is_
i
nitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">>> Finished adding block data to index"
,
flush
=
True
)
print
(
">>> Finished adding block data to index"
,
flush
=
True
)
def
search_mips_index
(
self
,
query_embeds
,
top_k
,
reconstruct
=
True
):
def
search_mips_index
(
self
,
query_embeds
,
top_k
,
reconstruct
=
True
):
...
...
megatron/indexer.py

@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist

 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
megatron/initialize.py

@@ -14,13 +14,10 @@ from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron.arguments import (parse_args, validate_args)
 from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
-from megatron.mpu import (set_tensor_model_parallel_rank,
-                          set_tensor_model_parallel_world_size)
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu

@@ -65,13 +62,14 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     args = get_args()
     if args.lazy_mpu_init:
+        # TODO is this still a necessary option?
         args.use_cpu_initialization = True
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
-        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
         # and return function for external DDP manager
         # to call when it has DDP initialized
-        set_tensor_model_parallel_rank(args.rank)
+        mpu.set_tensor_model_parallel_rank(args.rank)
         return finish_mpu_init
     else:
         # Megatron's MPU is the master. Complete initialization right away.

@@ -147,7 +145,7 @@ def _compile_dependencies():

 def _initialize_distributed():
-    """Initialize torch.distributed and mpu."""
+    """Initialize torch.distributed and core model parallel."""
     args = get_args()

     device_count = torch.cuda.device_count()

@@ -188,14 +186,11 @@ def _initialize_distributed():
                                           args.pipeline_model_parallel_size,
                                           args.virtual_pipeline_model_parallel_size,
                                           args.pipeline_model_parallel_split_rank)
-            core.initialize_model_parallel(args.tensor_model_parallel_size,
-                                           args.pipeline_model_parallel_size,
-                                           args.virtual_pipeline_model_parallel_size,
-                                           args.pipeline_model_parallel_split_rank)
-            print(f'> initialized tensor model parallel with size '
-                  f'{core.get_tensor_model_parallel_world_size()}')
-            print(f'> initialized pipeline model parallel with size '
-                  f'{core.get_pipeline_model_parallel_world_size()}')
+            if args.rank == 0:
+                print(f'> initialized tensor model parallel with size '
+                      f'{mpu.get_tensor_model_parallel_world_size()}')
+                print(f'> initialized pipeline model parallel with size '
+                      f'{mpu.get_pipeline_model_parallel_world_size()}')

 def _init_autoresume():

@@ -219,7 +214,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
         np.random.seed(seed)
         torch.manual_seed(seed)
         if torch.cuda.device_count() > 0:
-            core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
+            tensor_parallel.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
megatron/model/bert_model.py

@@ -5,7 +5,7 @@
 import torch

 from megatron import get_args
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model

@@ -61,7 +61,7 @@ class BertLMHead(MegatronModule):
         args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
+        tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

@@ -110,9 +110,9 @@ def post_language_model_processing(lm_output, pooled_output,
         # lm_logits : [s, b, h] and lm_labels: [s, b]
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
         else:
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                        lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                                   lm_labels)
         # [s, b] => [b s]
         lm_loss = lm_loss.transpose(0,1).contiguous()
megatron/model/biencoder_model.py

@@ -2,11 +2,11 @@ import os
 import torch
 import sys

-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_0, get_tokenizer
+from megatron.core import mpu
 from megatron.checkpointing import fix_query_key_value_ordering
 from megatron.checkpointing import get_checkpoint_tracker_filename
 from megatron.checkpointing import get_checkpoint_name
-from megatron import mpu, get_tokenizer
 from megatron.model.bert_model import bert_position_ids
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import get_language_model
megatron/model/classification.py

@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model