Commit aed2f75e authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into github-main

parents 8aa4619f f32a638d
......@@ -19,7 +19,8 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
......@@ -27,46 +28,57 @@ from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
class ClassificationBase(MegatronModule):
class Classification(MegatronModule):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationBase, self).__init__(share_word_embeddings=False)
def __init__(self,
num_classes,
num_tokentypes=2,
pre_process=True,
post_process=True):
super(Classification, self).__init__(share_word_embeddings=False)
args = get_args()
self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
# Multi-choice head.
if mpu.is_pipeline_last_stage():
if self.post_process:
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes,
init_method)
self._classification_head_key = 'classification_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
kwargs = {}
if mpu.is_pipeline_first_stage():
input_ids = model_input
position_ids = bert_position_ids(input_ids)
args = [input_ids, position_ids, extended_attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
else:
args = [model_input, extended_attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage():
if self.post_process:
_, pooled_output = lm_output
classification_output = self.classification_dropout(pooled_output)
classification_logits = self.classification_head(classification_output)
......@@ -86,7 +98,7 @@ class ClassificationBase(MegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage():
if self.post_process:
state_dict_[self._classification_head_key] \
= self.classification_head.state_dict(
destination, prefix, keep_vars)
......@@ -97,7 +109,7 @@ class ClassificationBase(MegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if mpu.is_pipeline_last_stage():
if self.post_process:
if self._classification_head_key in state_dict:
self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict)
......@@ -105,55 +117,3 @@ class ClassificationBase(MegatronModule):
print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))
class Classification(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(Classification, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(Classification, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationFirstStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationFirstStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(ClassificationFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationIntermediateStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationIntermediateStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationIntermediateStage, self).forward(
hidden_state,
attention_mask)
class ClassificationLastStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationLastStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationLastStage, self).forward(
hidden_state,
attention_mask)
This diff is collapsed.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -6,11 +6,12 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False):
......@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule):
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
This diff is collapsed.
......@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
......
This diff is collapsed.
This diff is collapsed.
......@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4
_MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment