Commit bc5c7fa7 authored Jan 07, 2025 by wxj
First test commit
parent 70fddd0f
Changes: 290
Showing 20 changed files with 3078 additions and 0 deletions.
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fp16_deprecated/loss_scaler.py (+26, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/__init__.py (+75, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/compat.h (+17, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/tests/__init__.py (+0, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/tests/test_fused_kernels.py (+388, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/type_shim.h (+103, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/indexer.py (+129, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/__init__.py (+10, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/bert_model.py (+257, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/biencoder_model.py (+328, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/classification.py (+101, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/enums.py (+21, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/fused_bias_gelu.py (+44, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/fused_layer_norm.py (+96, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/fused_softmax.py (+213, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/gpt_model.py (+122, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/language_model.py (+626, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/module.py (+206, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/multiple_choice.py (+112, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/realm_model.py (+204, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fp16_deprecated/loss_scaler.py
0 → 100755
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""For backward compatibility, we need the class definitions to deserialize."""

class LossScaler:
    def __init__(self, scale=1):
        self.cur_scale = scale

class DynamicLossScaler:
    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000,
                 min_scale=1,
                 delayed_shift=1,
                 consecutive_hysteresis=False):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.delayed_shift = delayed_shift
        self.cur_hysteresis = delayed_shift
        self.consecutive_hysteresis = consecutive_hysteresis
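Note: these classes exist only so that old checkpoints which pickled `fp16.loss_scaler` objects can still be deserialized. As a hedged illustration (not part of this commit, but the same module-aliasing trick appears in `init_state_dict_from_bert` in biencoder_model.py later in this diff; the checkpoint path below is hypothetical):

# Sketch only: load a legacy checkpoint that pickled fp16.loss_scaler objects
# by aliasing the old module paths to the deprecated shim above.
import sys
import torch
from megatron.legacy.fp16_deprecated import loss_scaler

sys.modules['fp16.loss_scaler'] = loss_scaler            # very old import path
sys.modules['megatron.fp16.loss_scaler'] = loss_scaler   # another old path
state_dict = torch.load('old_checkpoint.pt', map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)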
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/__init__.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating different
# compilation commands (with different order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicitly in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load(args):

    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
        cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')
        if int(bare_metal_minor) >= 8:
            cc_flag.append('-gencode')
            cc_flag.append('arch=compute_90,code=sm_90')

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=["-O3",],
            extra_cuda_cflags=["-O3",
                               "-gencode", "arch=compute_70,code=sm_70",
                               "--use_fast_math"] + extra_cuda_flags + cc_flag,
            verbose=(args.rank == 0),
        )


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                         universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
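In the version shown here, load(args) only reads args.rank (to decide whether the nvcc build log is printed) and prepares the architecture flags and build directory. A hedged sketch of how it might be driven outside Megatron's own argument parser; the SimpleNamespace stand-in is an assumption, not part of this commit:

# Minimal sketch, assuming Megatron-LM is importable and a CUDA toolkit is on PATH.
from types import SimpleNamespace
from megatron.legacy import fused_kernels

args = SimpleNamespace(rank=0)   # rank 0 prints the verbose build output
fused_kernels.load(args)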
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/compat.h
0 → 100644
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */

/* This code is copied from NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 * with minor changes. */

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/tests/__init__.py
0 → 100644
(empty file)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/tests/test_fused_kernels.py
0 → 100644
import math

import torch
from torch.nn import LayerNorm

from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.legacy.model.utils import attention_mask_func
from megatron.legacy.fused_kernels import load


def test_load_fused_kernels():
    try:
        import fused_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False, True, True, ..., True, True, True],
              [False, False, True, ..., True, True, True],
              [False, False, False, ..., True, True, True],
              ...,
              [False, False, False, ..., False, True, True],
              [False, False, False, ..., False, False, True],
              [False, False, False, ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


def forward_torch_softmax(input, mask, scale):
    input = input * scale
    mask_output = attention_mask_func(input, mask) if mask is not None else input
    probs = torch.nn.Softmax(dim=-1)(mask_output)
    return probs


def test_masked_softmax_forward():
    import scaled_masked_softmax_cuda

    batch = 2
    attn = 16
    scale_t = torch.tensor([1.0])
    for qlen in [128, 256, 1024, 2048, 4096]:
        for klen in [128, 256, 1024, 2048]:
            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
            error = (softmax_results_torch - softmax_results).abs().max()
            assert error < 1e-3


def test_masked_softmax_backward():
    import scaled_masked_softmax_cuda

    batch = 2
    attn = 16
    scale_t = torch.tensor([1.0])
    for qlen in [128, 256, 1024, 2048, 4096]:
        for klen in [128, 256, 1024, 2048]:
            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())

            inputs.requires_grad = True
            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
            softmax_results_torch.backward(backward)
            error = (back_grad - inputs.grad).abs().max()
            assert error < 1e-3


def test_allmasked_softmax_forward():
    import scaled_masked_softmax_cuda

    batch = 2
    attn = 16
    scale_t = torch.tensor([1.0])
    for qlen in [128, 256, 1024, 2048, 4096]:
        for klen in [128, 256, 1024, 2048]:
            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
            softmax_results_torch = torch.zeros_like(inputs)
            error = (softmax_results_torch - softmax_results).abs().max()
            assert error == 0.0


def test_allmasked_softmax_backward():
    import scaled_masked_softmax_cuda

    batch = 2
    attn = 16
    scale_t = torch.tensor([1.0])
    for qlen in [128, 256, 1024, 2048, 4096]:
        for klen in [128, 256, 1024, 2048]:
            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
            inputs.requires_grad = True
            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
            softmax_results_torch.backward(backward)
            error = (back_grad - inputs.grad).abs().max()
            assert error < 1e-3


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    load()
    test_masked_softmax_forward()
    test_masked_softmax_backward()
    test_allmasked_softmax_forward()
    test_allmasked_softmax_backward()
    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
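Each test above reduces the elementwise |fused - torch| difference to a scalar by repeatedly averaging over the last dimension, then compares it against a 1e-3 tolerance. A minimal standalone version of that check (names and shapes are illustrative, not from this commit):

# Sketch of the tolerance check used by the tests above (illustrative only).
import torch

def mean_abs_diff(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    diff = (a - b).abs()
    while diff.dim() != 1:          # collapse trailing dims one at a time
        diff = diff.mean(dim=-1)
    return diff.mean(dim=-1)        # final scalar

x = torch.randn(4, 12, 32, 32)
assert mean_abs_diff(x, x + 1e-5) <= 1e-3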
Megatron-LM-core_r0.7.0.beta/megatron/legacy/fused_kernels/type_shim.h
0 → 100644
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
Megatron-LM-core_r0.7.0.beta/megatron/legacy/indexer.py
0 → 100644
import sys
import time
import torch
import torch.distributed as dist

from megatron.training import get_args, print_rank_0
from megatron.core import mpu
from megatron.training.checkpointing import load_biencoder_checkpoint
from megatron.legacy.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.legacy.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.legacy.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.legacy.data.realm_index import detach, OpenRetreivalDataStore
from megatron.legacy.model.biencoder_model import get_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(get_model_provider(only_context_model=\
            only_context_model, biencoder_shared_query_context_model=\
            self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
            only_context_model=only_context_model)

        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
            self.batch_size))

        self.evidence_embedder_obj = OpenRetreivalDataStore( \
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]

        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch( \
                        self.dataloader)
            except (StopIteration, IndexError):
                break

            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/__init__.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from .rms_norm import RMSNorm

from .bert_model import BertModel
from .gpt_model import GPTModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/bert_model.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""BERT model."""

import torch

from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.language_model import parallel_lm_logits
from megatron.legacy.model.language_model import get_language_model
from megatron.legacy.model.utils import get_norm
from megatron.legacy.model.utils import openai_gelu, erf_gelu
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.utils import init_method_normal
from megatron.legacy.model.utils import scaled_init_method_normal
from .module import MegatronModule


def bert_extended_attention_mask(attention_mask):
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert attention mask to binary:
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask


def bert_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids


class BertLMHead(MegatronModule):
    """Masked LM head for Bert

    Args:
        config: TransformerConfig object
        mpu_vocab_size: model parallel size of vocabulary.
        parallel_output: whether output logits being distributed or not.
    """

    def __init__(self, mpu_vocab_size, config, parallel_output):
        super().__init__(config=config)

        args = get_args()
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(config.hidden_size, config.hidden_size,
                                      config.init_method)
        setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
        setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)

        self.norm = get_norm(config)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.norm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output

    def load_state_dict(self, state_dict, strict=True):
        """Customize load."""

        # Handle renaming layernorm -> norm in component names
        state_dict_ = {}
        for key in state_dict.keys():
            newkey = key.replace("layernorm", "norm")
            state_dict_[newkey] = state_dict[key]

        super().load_state_dict(state_dict_, strict)


def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
    # Output.
    lm_logits = lm_head(lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        # [s b h] => [b s h]
        return lm_logits.transpose(0, 1).contiguous(), binary_logits
    else:
        # [b s] => [s b]
        lm_labels = lm_labels.transpose(0, 1).contiguous()
        # lm_logits : [s, b, h] and lm_labels: [s, b]
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                   lm_labels)
        # [s, b] => [b s]
        lm_loss = lm_loss.transpose(0, 1).contiguous()
        return lm_loss, binary_logits


class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self,
                 config,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super().__init__(config=config)
        args = get_args()

        # TODO this option is not yet implemented in BERT
        assert args.untie_embeddings_and_output_weights is False

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        self.return_embeddings = args.output_bert_embeddings
        if self.return_embeddings:
            assert self.post_process and self.add_binary_head

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings()
        if self.post_process:
            self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0),
                                      config, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(config.hidden_size, 2,
                                                    config.init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output

            # Return pooled output (e.g., when computing Bert embeddings).
            if self.return_embeddings:

                # Sum attention mask.
                embeddings = torch.transpose(lm_output, 0, 1)
                masks = torch.sum(attention_mask, dim=1)

                # Collect masked embeddings.
                output = torch.zeros(
                    size=(embeddings.shape[0], embeddings.shape[2]),
                    dtype=torch.float32,
                    device=torch.cuda.current_device())
                for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
                    output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)

                return output

        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.shared_embedding_or_output_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/biencoder_model.py
0 → 100644
import os
import torch
import sys

from megatron.training import get_args, print_rank_0, get_tokenizer
from megatron.core import mpu
from megatron.training.checkpointing import fix_query_key_value_ordering
from megatron.training.checkpointing import get_checkpoint_tracker_filename
from megatron.training.checkpointing import get_checkpoint_name
from megatron.legacy.model.bert_model import bert_position_ids
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.language_model import get_language_model
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.utils import init_method_normal
from megatron.legacy.model.utils import scaled_init_method_normal
from .module import MegatronModule


def get_model_provider(only_query_model=False, only_context_model=False,
                       biencoder_shared_query_context_model=False):

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""

        print_rank_0('building Bienoder model ...')
        model = biencoder_model_provider(only_query_model=only_query_model,
                only_context_model=only_context_model,
                biencoder_shared_query_context_model=\
                biencoder_shared_query_context_model,
                pre_process=pre_process, post_process=post_process)

        return model

    return model_provider


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model."""

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=False,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=\
        biencoder_shared_query_context_model,
        pre_process=pre_process,
        post_process=post_process)

    return model


class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False,
                 pre_process=True,
                 post_process=True):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output,
            pre_process=pre_process,
            post_process=post_process)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.biencoder_projection_dim = args.biencoder_projection_dim

        if self.biencoder_shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        # this is just a placeholder and will be needed when model
        # parallelism will be used
        # self.language_model.set_input_tensor(input_tensor)
        return

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings."""

        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")
        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")
        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model"""
        logits = model(tokens,
                       attention_mask,
                       token_types)
        return logits

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.biencoder_shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(prefix=prefix,
                                                          keep_vars=keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        prefix=prefix, keep_vars=keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        prefix=prefix, keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.biencoder_shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key], \
                strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict( \
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict( \
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining"""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from megatron.legacy.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
        except BaseException:
            print_rank_0('could not load the BERT checkpoint')
            sys.exit()

        checkpoint_version = state_dict.get('checkpoint_version', 0)

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.biencoder_shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
            fix_query_key_value_ordering(self.model, checkpoint_version)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # give each model the same ict_head to begin with as well
                if self.biencoder_projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']
                fix_query_key_value_ordering(self.query_model, checkpoint_version)

            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if self.query_model is not None and \
                    self.biencoder_projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict\
                        (query_proj_state_dict)
                fix_query_key_value_ordering(self.context_model, checkpoint_version)


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval."""

    def __init__(self, num_tokentypes=2,
                 parallel_output=True, pre_process=True, post_process=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.biencoder_projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)
        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Taking the representation of the [CLS] token of BERT
        pooled_output = lm_output[0, :, :]

        # Converting to float16 dtype
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.biencoder_projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(prefix=prefix,
                                               keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading pretrained weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/classification.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Classification model."""

import torch

from megatron.training import get_args, print_rank_last
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.legacy.model.language_model import get_language_model
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.utils import init_method_normal
from megatron.legacy.model.utils import scaled_init_method_normal
from .module import MegatronModule


class Classification(MegatronModule):

    def __init__(self,
                 config,
                 num_classes,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        super().__init__(config=config, share_embeddings_and_output_weights=False)
        args = get_args()

        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Multi-choice head.
        if self.post_process:
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head = get_linear_layer(args.hidden_size,
                                                        self.num_classes,
                                                        config.init_method)
            self._classification_head_key = 'classification_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process:
            _, pooled_output = lm_output
            classification_output = self.classification_dropout(pooled_output)
            classification_logits = self.classification_head(classification_output)

            # Reshape back to separate choices.
            classification_logits = classification_logits.view(-1, self.num_classes)

            return classification_logits
        return lm_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process:
            state_dict_[self._classification_head_key] \
                = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._classification_head_key in state_dict:
                self.classification_head.load_state_dict(
                    state_dict[self._classification_head_key], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._classification_head_key))
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/enums.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import enum

class LayerType(enum.Enum):
    encoder = 1
    decoder = 2
    retro_encoder = 3
    retro_decoder = 4
    retro_decoder_with_retriever = 5

class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2

# For backward compatibility with old model checkpoints
from megatron.core.enums import ModelType
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/fused_bias_gelu.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
from megatron.core.jit import jit_fuser


###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

@jit_fuser
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp

bias_gelu_impl = GeLUFunction.apply
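bias_gelu_impl fuses the bias add with the tanh GELU approximation in a single autograd function. A hedged usage sketch (shapes are illustrative and chosen so the custom backward's bias gradient matches; in Megatron the bias typically comes from a parallel linear layer instead):

# Illustrative usage; assumes Megatron-LM is importable so the module above loads.
import torch
from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl

hidden = torch.randn(8, 1024, requires_grad=True)
bias = torch.zeros(8, 1024, requires_grad=True)  # same shape as `hidden` for this sketch

out = bias_gelu_impl(hidden, bias)   # forward: bias add + tanh-approx GELU
out.sum().backward()                 # backward runs bias_gelu_back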
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/fused_layer_norm.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied from NVIDIA apex:
      https://github.com/NVIDIA/apex
   with some changes. """

import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

from megatron.core.utils import make_viewless_tensor

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except:
    HAVE_PERSIST_LAYER_NORM = False

try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
except:
    fused_layer_norm_affine = None

global fused_layer_norm_cuda
fused_layer_norm_cuda = None


class MixedFusedLayerNorm(torch.nn.Module):

    def __init__(self, normalized_shape, eps=1e-5,
                 no_persist_layer_norm=True,
                 sequence_parallel=False,
                 apply_layernorm_1p=False):
        super(MixedFusedLayerNorm, self).__init__()

        self.apply_layernorm_1p = apply_layernorm_1p

        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        # List of hiddens sizes supported in the persistent layer norm kernel
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
            24576, 25600, 30720, 32768, 40960, 49152, 65536]
        if normalized_shape not in persist_ln_hidden_sizes or \
                not HAVE_PERSIST_LAYER_NORM:
            no_persist_layer_norm = True

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # set sequence parallelism flag on weight and bias parameters
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

    def reset_parameters(self):

        if self.apply_layernorm_1p:
            init.zeros_(self.weight)
            init.zeros_(self.bias)
        else:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):

        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight

        if self.no_persist_layer_norm:
            assert fused_layer_norm_affine is not None, \
                "fused_layer_norm_affine is not available, please install apex from https://github.com/NVIDIA/apex"
            return fused_layer_norm_affine(input, weight, self.bias,
                                           self.normalized_shape, eps=self.eps)
        else:
            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
            # a populated '_base' field). This will result in schedule.py's
            # deallocate_output_tensor() throwing an error, so a viewless tensor is
            # created to prevent this.
            output = make_viewless_tensor(inp=output,
                                          requires_grad=input.requires_grad,
                                          keep_graph=True)

        return output
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/fused_softmax.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.nn as nn
from megatron.legacy.model.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following two operations in sequence
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Args:
        input_in_fp16: flag to indicate if input in fp16 data format.
        input_in_bf16: flag to indicate if input in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and 16 < sk <= 16384  # sk must be 16 ~ 16384
            and sq % 4 == 0  # sq must be divisor of 4
            and sk % 4 == 0  # sk must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 16384:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            if mask is not None:
                return ScaledMaskedSoftmax.apply(input, mask, scale)
            else:
                return ScaledSoftmax.apply(input, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)
if
self
.
input_in_float16
and
self
.
softmax_in_fp32
:
if
self
.
input_in_fp16
:
probs
=
probs
.
half
()
else
:
probs
=
probs
.
bfloat16
()
return
probs
@
staticmethod
def
get_batch_per_block
(
sq
,
sk
,
b
,
np
):
import
scaled_masked_softmax_cuda
return
scaled_masked_softmax_cuda
.
get_batch_per_block
(
sq
,
sk
,
b
,
np
)
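Note: when the fused CUDA kernel is not applicable, `FusedScaleMaskSoftmax` falls back to `forward_torch_softmax`. The following stand-alone sketch mirrors that fallback path with hypothetical sizes; the fill-based mask stands in for the `mask_func` argument and is not the module's actual callable:

# illustrative sketch only -- not part of the committed file
import torch

b, np_, sq, sk = 2, 4, 8, 8
scores = torch.randn(b, np_, sq, sk, dtype=torch.float16)
mask = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)  # causal mask

x = scores.float()                           # softmax_in_fp32 path
x = x * 0.125                                # optional scale
x = x.masked_fill(mask, -10000.0)            # stand-in for mask_func(input, mask)
probs = torch.nn.Softmax(dim=-1)(x).half()   # cast back to the input dtype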
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/gpt_model.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""GPT-2 model."""

import torch

from megatron.training import get_args
from megatron.core import tensor_parallel
from .module import MegatronModule

from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model


def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy):

    # Output. Format [s b h]
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if labels is None:
        # [s b h] => [b s h]
        return output.transpose(0, 1).contiguous()
    else:
        # [b s] => [s b]
        labels = labels.transpose(0, 1).contiguous()

        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

        # [s b] => [b, s]
        loss = loss.transpose(0, 1).contiguous()
        return loss


class GPTModel(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self,
                 config,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        args = get_args()
        super().__init__(config=config,
                         share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if not args.untie_embeddings_and_output_weights:
            self.initialize_word_embeddings()

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask,
                retriever_input_ids=None,
                retriever_position_ids=None,
                retriever_attn_mask=None,
                labels=None, tokentype_ids=None, inference_params=None):

        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            retriever_input_ids=retriever_input_ids,
            retriever_position_ids=retriever_position_ids,
            retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params)

        if self.post_process:
            return post_language_model_processing(
                lm_output, labels,
                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
                self.parallel_output,
                self.fp16_lm_cross_entropy)
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Load word_embeddings.
        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
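Note: `post_language_model_processing` works in sequence-first layout ([s, b, ...]) and transposes the labels and loss at the boundaries. A minimal single-GPU sketch of the same shape flow, substituting `torch.nn.functional.cross_entropy` for the vocab-parallel loss and using hypothetical sizes:

# illustrative sketch only -- not part of the committed file
import torch
import torch.nn.functional as F

s, b, v = 5, 2, 11                      # sequence length, batch, (padded) vocab
logits = torch.randn(s, b, v)           # [s, b, v], as produced by parallel_lm_logits
labels = torch.randint(0, v, (b, s))    # [b, s], as fed by the trainer

labels_sb = labels.transpose(0, 1).contiguous()             # [b, s] => [s, b]
loss_sb = F.cross_entropy(logits.reshape(-1, v),            # per-token loss
                          labels_sb.reshape(-1),
                          reduction='none').view(s, b)
loss = loss_sb.transpose(0, 1).contiguous()                 # [s, b] => [b, s]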
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/language_model.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from megatron.training import get_args
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

from .enums import AttnMaskType, LayerType
from .module import MegatronModule
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    args = get_args()
    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce or \
            args.sequence_parallel:
        input_parallel = input_
        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
            model_parallel and not args.sequence_parallel
    else:
        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
        async_grad_allreduce = False

    # Matrix multiply.
    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        async_grad_allreduce=async_grad_allreduce,
        sequence_parallel=args.sequence_parallel)

    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


def get_language_model(config, num_tokentypes, add_pooler,
                       encoder_attn_mask_type,
                       add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save."""
    args = get_args()
    if config.init_method is None:
        config.init_method = init_method_normal(config.init_method_std)

    if config.output_layer_init_method is None:
        config.output_layer_init_method = scaled_init_method_normal(
            config.init_method_std, config.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        config,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Args:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [s, b, h]
        # sequence_index: index of the token to pool.

        # gather data along sequence dimensions
        # same pooler is run on all tensor parallel nodes
        if self.sequence_parallel:
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states,
                tensor_parallel_output_grad=False)

        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    Args:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 config,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = config.init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.params_dtype = args.params_dtype
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size, config=config, init_method=config.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
        if self.add_position_embedding:
            self.position_embeddings = torch.nn.Embedding(
                max_sequence_length, self.hidden_size)
            self._position_embeddings_key = 'position_embeddings'
            # Initialize the position embeddings.
            if args.perform_initialization:
                self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        if self.add_position_embedding:
            self.position_embeddings.weight.data.fill_(0)
            self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        args = get_args()
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = words_embeddings + position_embeddings
        else:
            embeddings = words_embeddings

        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert to float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            # `scatter_to_sequence_parallel_region` returns a view, which prevents
            # the original tensor from being garbage collected. Clone to facilitate GC.
            # Has a small runtime cost (~0.5%).
            if self.clone_scatter_output_in_embedding:
                embeddings = embeddings.clone()
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        if self.add_position_embedding:
            state_dict_[self._position_embeddings_key] \
                = self.position_embeddings.state_dict(prefix=prefix,
                                                      keep_vars=keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(prefix=prefix,
                                                       keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] \
                        = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self.add_position_embedding:
            if self._position_embeddings_key in state_dict:
                state_dict_ = state_dict[self._position_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'position_embeddings' in key:
                        state_dict_[key.split('position_embeddings.')[1]] \
                            = state_dict[key]
            self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] \
                            = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Args:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 config,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        args = get_args()
        # TODO: passing share_embeddings_and_output_weights=False will not work correctly
        # for T5 and embeddings will not be synced. Fix later for T5.
        if args.untie_embeddings_and_output_weights:
            assert not add_decoder
        super(TransformerLanguageModel, self).__init__(
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = config.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = config.init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None
        self.add_retriever = args.retro_add_retriever
        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       config,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Rotary positional embeddings
        self.use_rotary_position_embeddings = \
            args.position_embedding_type == 'rope'
        if self.use_rotary_position_embeddings:
            self.seq_length = args.seq_length
            rotary_dim = args.hidden_size // args.num_attention_heads \
                if args.kv_channels is None else args.kv_channels

            # partial rotary embeddings, which is better than full rotary
            # Wang and Komatsuzaki et al
            # https://github.com/kingoflolz/mesh-transformer-jax/
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=rotary_dim,
                rotary_percent=args.rotary_percent,
                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
            )

        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                config,
                model_type=args.model_type if not args.retro_add_retriever \
                    else ModelType.retro_decoder,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                config,
                model_type=args.model_type,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

            if self.untie_embeddings_and_output_weights:
                self.output_layer = tensor_parallel.ColumnParallelLinear(
                    args.hidden_size,
                    args.padded_vocab_size,
                    config=config,
                    init_method=self.init_method,
                    bias=False)  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
                self._output_layer_key = 'output_layer'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                retriever_input_ids=None,
                retriever_position_ids=None,
                retriever_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Retriever embedding.
        if self.add_retriever and self.pre_process:
            retriever_input = self.embedding(retriever_input_ids,
                                             retriever_position_ids,
                                             tokentype_ids=tokentype_ids)
        else:
            retriever_input = None

        # Rotary positional embeddings
        rotary_pos_emb = None
        if self.use_rotary_position_embeddings:
            if inference_params is not None:
                rotary_pos_emb = \
                    self.rotary_pos_emb(inference_params.max_sequence_length)
            else:
                rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    retriever_input=retriever_input,
                    retriever_attn_mask=retriever_attn_mask,
                    inference_params=inference_params,
                    rotary_pos_emb=rotary_pos_emb)
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
                                                                keep_vars=keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
            if self.untie_embeddings_and_output_weights:
                state_dict_[self._output_layer_key] \
                    = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)

        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.",
                                                          ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert 'pooler' in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
            if self.untie_embeddings_and_output_weights:
                assert 'output_layer' in state_dict, \
                    'could not find data for output_layer in the checkpoint'
                self.output_layer.load_state_dict(state_dict[self._output_layer_key],
                                                  strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)
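Note: the rotary setup in `TransformerLanguageModel.__init__` derives the per-head rotary dimension from `hidden_size // num_attention_heads` unless `kv_channels` overrides it, and `rotary_percent` controls the fraction of those channels that are actually rotated (partial rotary, as in the Mesh Transformer JAX reference above). A small arithmetic sketch with hypothetical sizes:

# illustrative sketch only -- not part of the committed file
hidden_size, num_attention_heads = 4096, 32
kv_channels = None                    # when set, it overrides hidden_size // num_heads
rotary_percent = 0.5                  # partial rotary: rotate only half the channels

rotary_dim = hidden_size // num_attention_heads if kv_channels is None else kv_channels
rotated = int(rotary_dim * rotary_percent)
print(rotary_dim, rotated)            # 128 64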
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/module.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron Module"""

import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from megatron.training import get_args
from megatron.core import mpu, tensor_parallel


_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)


def param_is_not_shared(param):
    return not hasattr(param, 'shared') or not param.shared


class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining."""

    def __init__(self, config=None, share_embeddings_and_output_weights=True):
        super(MegatronModule, self).__init__()
        self.config = config
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

    def shared_embedding_or_output_weight(self):
        if self.pre_process:
            return self.language_model.embedding.word_embeddings.weight
        else:
            if not self.share_embeddings_and_output_weights:
                raise Exception('shared_embedding_or_output_weight() called for last '
                                'stage, but share_embeddings_and_output_weights is false')
            return self.word_embeddings.weight

    def initialize_word_embeddings(self):
        args = get_args()
        if not self.share_embeddings_and_output_weights:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_embeddings_and_output_weights is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. Nothing to do if we aren't
        # using pipeline parallelism.
        if args.pipeline_model_parallel_size == 1:
            # Zero out wgrad if sharing embeddings between two layers on same
            # pipeline stage to make sure grad accumulation into main_grad is
            # correct and does not include garbage values (e.g., from torch.empty).
            self.shared_embedding_or_output_weight().zero_out_wgrad = True
            return

        if mpu.is_pipeline_first_stage() and self.pre_process and not self.post_process:
            self.shared_embedding_or_output_weight().shared_embedding = True

        # Parameters are shared between the word embeddings layers, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, do an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if mpu.is_pipeline_last_stage() and not self.pre_process:
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                args.padded_vocab_size, self.config.hidden_size,
                config=self.config, init_method=self.config.init_method)
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True
            self.word_embeddings.weight.shared_embedding = True

        # Zero out initial weights for decoder embedding.
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
                self.pre_process:
            self.language_model.embedding.zero_parameters()

        if not torch.distributed.is_initialized():
            if not getattr(MegatronModule, "embedding_warning_printed", False):
                print("WARNING! Distributed processes aren't initialized, so "
                      "word embeddings in the last layer are not initialized. "
                      "If you are just manipulating a model this is fine, but "
                      "this needs to be handled manually. If you are training "
                      "something is definitely wrong.")
                MegatronModule.embedding_warning_printed = True
            return

        # Ensure that first and last stages have the same initial parameter
        # values.
        if mpu.is_rank_in_embedding_group():
            self.shared_embedding_or_output_weight().data = \
                self.shared_embedding_or_output_weight().data.cuda()
            torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data,
                                         group=mpu.get_embedding_group())

        # Ensure that encoder(first stage) and decoder(split stage) position
        # embeddings have the same initial parameter values
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if mpu.is_rank_in_position_embedding_group() and \
                args.pipeline_model_parallel_split_rank is not None:
            # TODO: Support tokentype embedding.
            self.language_model.embedding.cuda()
            position_embeddings = self.language_model.embedding.position_embeddings
            torch.distributed.all_reduce(position_embeddings.weight.data,
                                         group=mpu.get_position_embedding_group())


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
    is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16"""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = float16_convertor(val)
        return val
    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32"""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)


class Float16Module(MegatronModule):

    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        if args.fp16:
            self.add_module('module', module.half())

            def float16_convertor(val):
                return val.half()

        elif args.bf16:
            self.add_module('module', module.bfloat16())

            def float16_convertor(val):
                return val.bfloat16()

        else:
            raise Exception('should not be here')

        self.float16_convertor = float16_convertor

    def set_input_tensor(self, input_tensor):
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, **kwargs):
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, prefix='', keep_vars=False):
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
                                                          keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
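Note: `conversion_helper` walks arbitrarily nested tuple/list structures and applies the conversion to every leaf, which is how `Float16Module.forward` casts whatever the pipeline hands it. A stand-alone sketch of the same recursion without the Megatron imports:

# illustrative sketch only -- not part of the committed file
import torch

def conversion_helper(val, conversion):
    # Recursively apply `conversion` to every leaf of a nested tuple/list.
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    return tuple(rtn) if isinstance(val, tuple) else rtn

batch = (torch.randn(2, 3), [torch.randn(4), torch.randn(5)])
half_batch = conversion_helper(batch, lambda t: t.half() if t.dtype == torch.float32 else t)
print([t.dtype for t in (half_batch[0], *half_batch[1])])   # all torch.float16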
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/multiple_choice.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Multiple choice model."""

import torch

from megatron.training import get_args, print_rank_last
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.legacy.model.language_model import get_language_model
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.utils import init_method_normal
from megatron.legacy.model.utils import scaled_init_method_normal
from .module import MegatronModule


class MultipleChoice(MegatronModule):

    def __init__(self,
                 config,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Multi-choice head.
        if self.post_process:
            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
            # Use the config's init method; the original referenced an undefined
            # local name `init_method` here.
            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                     config.init_method)
            self._multichoice_head_key = 'multichoice_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
        assert len(attention_mask.shape) == 3
        num_choices = attention_mask.shape[1]

        # Reshape and treat choice dimension the same as batch.
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        input_ids = model_input
        # Do the same as attention_mask for input_ids, tokentype_ids
        assert len(input_ids.shape) == 3
        assert len(tokentype_ids.shape) == 3
        input_ids = input_ids.view(-1, input_ids.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )
        if self.post_process:
            _, pooled_output = lm_output
            multichoice_output = self.multichoice_dropout(pooled_output)
            multichoice_logits = self.multichoice_head(multichoice_output)

            # Reshape back to separate choices.
            multichoice_logits = multichoice_logits.view(-1, num_choices)

            return multichoice_logits
        return lm_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process:
            state_dict_[self._multichoice_head_key] \
                = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._multichoice_head_key in state_dict:
                self.multichoice_head.load_state_dict(
                    state_dict[self._multichoice_head_key], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._multichoice_head_key))
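Note: the multiple-choice forward pass folds the choice dimension into the batch dimension, runs the encoder once, then unfolds the per-choice scores. A minimal shape sketch with stand-ins for the pooled LM output and the head (hypothetical sizes):

# illustrative sketch only -- not part of the committed file
import torch

batch, choices, seq, hidden = 4, 3, 16, 8
input_ids = torch.randint(0, 100, (batch, choices, seq))

flat_ids = input_ids.view(-1, input_ids.size(-1))      # [batch * choices, seq]
pooled = torch.randn(flat_ids.size(0), hidden)         # stand-in for the pooled LM output
head = torch.nn.Linear(hidden, 1)                      # stand-in for the multichoice head
logits = head(pooled).view(-1, choices)                # back to [batch, choices]
print(logits.shape)                                    # torch.Size([4, 3])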
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/realm_model.py
0 → 100644
import os

import torch

from megatron.training import get_args, print_rank_0
from megatron.training.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.legacy.model import BertModel
from .module import MegatronModule
from megatron.core import mpu
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.utils import init_method_normal
from megatron.legacy.model.language_model import get_language_model
from megatron.legacy.model.utils import scaled_init_method_normal
from megatron.legacy.model.bert_model import bert_extended_attention_mask, bert_position_ids


def general_ict_model_provider(only_query_model=False, only_block_model=False):
    """Build the model."""
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')

    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
    model = ICTBertModel(
        ict_head_size=args.ict_head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model)

    return model


class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task."""

    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        super(ICTBertModel, self).__init__()
        bert_kwargs = dict(
            ict_head_size=ict_head_size,
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output
        )
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = IREncoderBertModel(**bert_kwargs)
            self._query_key = 'question_model'

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = IREncoderBertModel(**bert_kwargs)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
        """Run a forward pass for each of the models and return the respective embeddings."""
        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)
        return query_logits, block_logits

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model"""
        if self.use_query_model:
            query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
            query_ict_logits, _ = self.query_model.forward(
                query_tokens, query_attention_mask, query_types)
            return query_ict_logits
        else:
            raise ValueError("Cannot embed query without query model.")

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model"""
        if self.use_block_model:
            block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
            block_ict_logits, _ = self.block_model.forward(
                block_tokens, block_attention_mask, block_types)
            return block_ict_logits
        else:
            raise ValueError("Cannot embed block without block model.")

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                  keep_vars=keep_vars)

        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                  keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(
                state_dict[self._query_key], strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(
                state_dict[self._block_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining"""
        args = get_args()
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT load for ICT")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except BaseException:
            raise ValueError("Could not load checkpoint")

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']
        self.query_model.language_model.load_state_dict(model_dict)
        self.block_model.language_model.load_state_dict(model_dict)

        # give each model the same ict_head to begin with as well
        query_ict_head_state_dict = \
            self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
        self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)


class IREncoderBertModel(MegatronModule):
    """BERT-based encoder for queries or blocks used for learned information retrieval."""

    def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
        super(IREncoderBertModel, self).__init__()
        args = get_args()

        self.ict_head_size = ict_head_size
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
        self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        lm_output, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)

        # Output.
        ict_logits = self.ict_head(pooled_output)
        return ict_logits, None

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)
        state_dict_[self._ict_head_key] \
            = self.ict_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        self.ict_head.load_state_dict(
            state_dict[self._ict_head_key], strict=strict)