Commit f772fbc9 authored by Jared Casper's avatar Jared Casper
Browse files

Only create task heads on last pipeline stage.

parent 6fa36844
......@@ -17,7 +17,7 @@
import torch
from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
......@@ -45,11 +45,12 @@ class ClassificationBase(PipelinedMegatronModule):
args.num_layers))
# Multi-choice head.
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes,
init_method)
self._classification_head_key = 'classification_head'
if mpu.is_pipeline_last_stage():
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes,
init_method)
self._classification_head_key = 'classification_head'
def forward(self, model_input, attention_mask, tokentype_ids=None):
......@@ -85,9 +86,10 @@ class ClassificationBase(PipelinedMegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._classification_head_key] \
= self.classification_head.state_dict(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage():
state_dict_[self._classification_head_key] \
= self.classification_head.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
......@@ -95,13 +97,14 @@ class ClassificationBase(PipelinedMegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self._classification_head_key in state_dict:
self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict)
else:
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))
if mpu.is_pipeline_last_stage():
if self._classification_head_key in state_dict:
self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict)
else:
print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))
class Classification(ClassificationBase):
......
......@@ -17,7 +17,7 @@
import torch
from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
......@@ -44,10 +44,11 @@ class MultipleChoiceBase(PipelinedMegatronModule):
args.num_layers))
# Multi-choice head.
self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
self.multichoice_head = get_linear_layer(args.hidden_size, 1,
init_method)
self._multichoice_head_key = 'multichoice_head'
if mpu.is_pipeline_last_stage():
self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
self.multichoice_head = get_linear_layer(args.hidden_size, 1,
init_method)
self._multichoice_head_key = 'multichoice_head'
def forward(self, model_input, attention_mask, tokentype_ids=None):
......@@ -97,9 +98,10 @@ class MultipleChoiceBase(PipelinedMegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage():
state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
......@@ -107,13 +109,14 @@ class MultipleChoiceBase(PipelinedMegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self._multichoice_head_key in state_dict:
self.multichoice_head.load_state_dict(
state_dict[self._multichoice_head_key], strict=strict)
else:
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._multichoice_head_key))
if mpu.is_pipeline_last_stage():
if self._multichoice_head_key in state_dict:
self.multichoice_head.load_state_dict(
state_dict[self._multichoice_head_key], strict=strict)
else:
print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._multichoice_head_key))
class MultipleChoice(MultipleChoiceBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment