OpenDAS / Megatron-LM

Commit 739cb43d, authored May 24, 2022 by Vijay Korthikanti
Commit message: resolved review comments
Parent: 9dc3c42a
Showing 4 changed files with 24 additions and 15 deletions (+24, -15):

  megatron/arguments.py           +7   -1
  megatron/global_vars.py         +13  -10
  megatron/model/transformer.py   +2   -2
  megatron/mpu/layers.py          +2   -2
megatron/arguments.py

@@ -304,7 +304,13 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.recompute_method is None, \
             'recompute method is not yet supported for ' \
             'selective recomputing granularity'
 
+    # disable sequence parallelism when tp=1
+    # to avoid change in numerics when
+    # sequence_parallelism is enabled.
+    if args.tensor_model_parallel_size == 1:
+        args.sequence_parallel = False
+
     # disable async_tensor_model_parallel_allreduce when
     # model parallel memory optimization is enabled
     if args.sequence_parallel:
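Note: the added guard only touches args, so its effect can be checked in isolation. The sketch below is not part of the commit; _sanitize_sequence_parallel is a hypothetical helper and argparse.Namespace merely stands in for Megatron's parsed arguments, but the body mirrors the added lines: with a single tensor-parallel rank, sequence_parallel is forced off, matching the commit's comment about avoiding a change in numerics.

    from argparse import Namespace

    def _sanitize_sequence_parallel(args):
        # Hypothetical helper mirroring the added check: with tp=1,
        # sequence parallelism is turned off to keep numerics unchanged.
        if args.tensor_model_parallel_size == 1:
            args.sequence_parallel = False
        return args

    args = _sanitize_sequence_parallel(
        Namespace(tensor_model_parallel_size=1, sequence_parallel=True))
    assert args.sequence_parallel is False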
megatron/global_vars.py

@@ -292,18 +292,21 @@ class Timers:
 class GlobalMemoryBuffer:
-    "Global buffer to avoid dynamic memory allocations"
+    """Global buffer to avoid dynamic memory allocations.
+       Caller should ensure that buffers of the same name
+       are not used concurrently."""
 
     def __init__(self):
         self.buffer = {}
 
-    def allocate_tensor(self, tensor_shape, dtype):
+    def get_tensor(self, tensor_shape, dtype, name):
         required_len = reduce(operator.mul, tensor_shape, 1)
-        if self.buffer.get(dtype, None) is None or \
-                self.buffer[dtype].numel() < required_len:
-            self.buffer[dtype] = torch.empty(required_len,
-                                             dtype=dtype,
-                                             device=torch.cuda.current_device(),
-                                             requires_grad=False)
-        return self.buffer[dtype][0:required_len].view(*tensor_shape)
+        if self.buffer.get((name, dtype), None) is None or \
+                self.buffer[(name, dtype)].numel() < required_len:
+            self.buffer[(name, dtype)] = \
+                torch.empty(required_len,
+                            dtype=dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
+
+        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
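Note: the key change here is that the scratch buffer is now indexed by (name, dtype) rather than dtype alone, so call sites that pass different names stop aliasing the same storage. The standalone sketch below is illustrative only: DemoBuffer is a hypothetical stand-in for the class above, and it allocates on the CPU so the snippet runs without a GPU, whereas the real buffer uses torch.cuda.current_device().

    import operator
    from functools import reduce

    import torch

    class DemoBuffer:
        """Keyed scratch buffer: entries are reused per (name, dtype) pair."""
        def __init__(self):
            self.buffer = {}

        def get_tensor(self, tensor_shape, dtype, name):
            required_len = reduce(operator.mul, tensor_shape, 1)
            key = (name, dtype)
            # Grow (or create) the flat backing tensor only when it is too small.
            if self.buffer.get(key) is None or self.buffer[key].numel() < required_len:
                self.buffer[key] = torch.empty(required_len, dtype=dtype,
                                               requires_grad=False)
            return self.buffer[key][0:required_len].view(*tensor_shape)

    buf = DemoBuffer()
    a = buf.get_tensor((2, 3), torch.float32, "mpu")
    b = buf.get_tensor((2, 3), torch.float32, "other")
    assert a.data_ptr() != b.data_ptr()  # distinct names -> distinct storage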
megatron/model/transformer.py

@@ -234,9 +234,9 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)
 
         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().allocate_tensor(
+        matmul_input_buffer = get_global_memory_buffer().get_tensor(
             (output_size[0]*output_size[1], output_size[2], output_size[3]),
-            dtype=query_layer.dtype)
+            query_layer.dtype, "mpu")
 
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
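Note: only the call site changes here; the named buffer still just preallocates the [b * np, sq, sk] input of the torch.baddbmm call visible at the bottom of the hunk. The sketch below uses made-up sizes and is not the Megatron code path; it only illustrates why an uninitialized scratch tensor (such as the one returned by get_tensor) is safe to pass when baddbmm is called with beta=0.0, since PyTorch then ignores the input's contents.

    import torch

    # Illustrative sizes only (hypothetical): b*np = 4, sq = sk = 8, hn = 16.
    q = torch.randn(4, 8, 16)                      # [b*np, sq, hn]
    k = torch.randn(4, 8, 16)                      # [b*np, sk, hn]

    # Scratch tensor with the right shape and dtype; its contents never matter.
    scratch = torch.empty(4, 8, 8, dtype=q.dtype)  # [b*np, sq, sk]

    # With beta=0.0 the first argument of baddbmm is ignored, so passing an
    # uninitialized buffer is safe here.
    scores = torch.baddbmm(scratch, q, k.transpose(1, 2), beta=0.0, alpha=1.0)
    assert torch.allclose(scores, torch.bmm(q, k.transpose(1, 2)))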
megatron/mpu/layers.py

@@ -221,7 +221,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size
 
             all_gather_buffer = \
-                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
             torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,

@@ -246,7 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size
 
             all_gather_buffer = \
-                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
             handle = torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,
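Note: both hunks make the same substitution, so the two all-gather call sites now draw their target tensor from the shared buffer under the "mpu" name instead of allocating one per call. The sketch below is a single-process illustration only: fake_all_gather_base is a made-up stand-in for torch.distributed._all_gather_base (which needs an initialized process group), and a plain torch.empty stands in for get_tensor(dim_size, inp.dtype, "mpu"); it just shows the dim-0 scaling that both hunks keep unchanged.

    import torch

    def fake_all_gather_base(output, inp, world_size):
        # Stand-in for the real collective: every "rank" contributes the same
        # tensor, concatenated along dim 0 into the preallocated output buffer.
        torch.cat([inp] * world_size, dim=0, out=output)
        return output

    world_size = 2
    inp = torch.randn(4, 2, 8)  # local input shard (illustrative sizes)

    # Same bookkeeping as both hunks: scale dim 0 by the world size, then take
    # the gather target from a reusable scratch buffer instead of allocating.
    dim_size = list(inp.size())
    dim_size[0] = dim_size[0] * world_size
    all_gather_buffer = torch.empty(*dim_size, dtype=inp.dtype)

    fake_all_gather_base(all_gather_buffer, inp, world_size)
    assert all_gather_buffer.shape == torch.Size([8, 2, 8])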