Commit f772fbc9 authored by Jared Casper

Only create task heads on last pipeline stage.

parent 6fa36844
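
With pipeline model parallelism, the transformer layers are split across stages, and only the last stage produces the final hidden states. The classification and multiple-choice heads below are therefore now created, checkpointed, and loaded only when mpu.is_pipeline_last_stage() is true. As a minimal sketch of what that guard checks (a hypothetical standalone version; Megatron's real implementation lives in megatron.mpu and reads ranks from the initialized model-parallel state):

    # Hypothetical sketch, not Megatron's actual code: the last stage is
    # the one whose pipeline-parallel rank is world_size - 1. It holds the
    # final transformer layers, so it is the only stage that needs a task head.
    def is_pipeline_last_stage(pipeline_rank, pipeline_world_size):
        return pipeline_rank == pipeline_world_size - 1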
@@ -17,7 +17,7 @@
 import torch
 
-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
@@ -45,11 +45,12 @@ class ClassificationBase(PipelinedMegatronModule):
                                      args.num_layers))
 
         # Multi-choice head.
-        self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
-        self.classification_head = get_linear_layer(args.hidden_size,
-                                                    self.num_classes,
-                                                    init_method)
-        self._classification_head_key = 'classification_head'
+        if mpu.is_pipeline_last_stage():
+            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
+            self.classification_head = get_linear_layer(args.hidden_size,
+                                                        self.num_classes,
+                                                        init_method)
+            self._classification_head_key = 'classification_head'
 
     def forward(self, model_input, attention_mask, tokentype_ids=None):
@@ -85,9 +86,10 @@ class ClassificationBase(PipelinedMegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._classification_head_key] \
-            = self.classification_head.state_dict(
-                destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            state_dict_[self._classification_head_key] \
+                = self.classification_head.state_dict(
+                    destination, prefix, keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
@@ -95,13 +97,14 @@ class ClassificationBase(PipelinedMegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if self._classification_head_key in state_dict:
-            self.classification_head.load_state_dict(
-                state_dict[self._classification_head_key], strict=strict)
-        else:
-            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
-                         'initializing to random'.format(
-                             self._classification_head_key))
+        if mpu.is_pipeline_last_stage():
+            if self._classification_head_key in state_dict:
+                self.classification_head.load_state_dict(
+                    state_dict[self._classification_head_key], strict=strict)
+            else:
+                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
+                                'initializing to random'.format(
+                                    self._classification_head_key))
 
 
 class Classification(ClassificationBase):
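
The same guard appears three times per file: at head creation, at checkpoint save, and at checkpoint load, so non-last stages never allocate the head's parameters and never expect its key in a checkpoint. Below is a small self-contained sketch of how the three pieces fit together; the names and the plain torch.nn.Module base are illustrative toys, not the commit's code:

    import torch

    class ToyTaskModel(torch.nn.Module):
        """Toy stand-in for ClassificationBase: the head exists only on
        the last pipeline stage, and save/load honor the same guard."""

        def __init__(self, hidden_size, num_classes, is_last_stage):
            super().__init__()
            self.is_last_stage = is_last_stage
            # Stand-in for the language-model trunk present on every stage.
            self.body = torch.nn.Linear(hidden_size, hidden_size)
            if self.is_last_stage:
                self.head = torch.nn.Linear(hidden_size, num_classes)

        def state_dict_for_save_checkpoint(self):
            state = {'body': self.body.state_dict()}
            if self.is_last_stage:
                state['head'] = self.head.state_dict()
            return state

        def load_checkpoint(self, state):
            self.body.load_state_dict(state['body'])
            if self.is_last_stage:
                if 'head' in state:
                    self.head.load_state_dict(state['head'])
                else:
                    print('***WARNING*** could not find head in the '
                          'checkpoint, initializing to random')

    # A first stage never allocates (or looks for) the head; the last stage does.
    first = ToyTaskModel(8, 3, is_last_stage=False)
    last = ToyTaskModel(8, 3, is_last_stage=True)
    last.load_checkpoint(last.state_dict_for_save_checkpoint())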
@@ -17,7 +17,7 @@
 import torch
 
-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
@@ -44,10 +44,11 @@ class MultipleChoiceBase(PipelinedMegatronModule):
                                      args.num_layers))
 
         # Multi-choice head.
-        self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
-        self.multichoice_head = get_linear_layer(args.hidden_size, 1,
-                                                 init_method)
-        self._multichoice_head_key = 'multichoice_head'
+        if mpu.is_pipeline_last_stage():
+            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
+            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
+                                                     init_method)
+            self._multichoice_head_key = 'multichoice_head'
 
     def forward(self, model_input, attention_mask, tokentype_ids=None):
@@ -97,9 +98,10 @@ class MultipleChoiceBase(PipelinedMegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._multichoice_head_key] \
-            = self.multichoice_head.state_dict(
-                destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            state_dict_[self._multichoice_head_key] \
+                = self.multichoice_head.state_dict(
+                    destination, prefix, keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
@@ -107,13 +109,14 @@ class MultipleChoiceBase(PipelinedMegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if self._multichoice_head_key in state_dict:
-            self.multichoice_head.load_state_dict(
-                state_dict[self._multichoice_head_key], strict=strict)
-        else:
-            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
-                         'initializing to random'.format(
-                             self._multichoice_head_key))
+        if mpu.is_pipeline_last_stage():
+            if self._multichoice_head_key in state_dict:
+                self.multichoice_head.load_state_dict(
+                    state_dict[self._multichoice_head_key], strict=strict)
+            else:
+                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
+                                'initializing to random'.format(
+                                    self._multichoice_head_key))
 
 
 class MultipleChoice(MultipleChoiceBase):
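
Both files also swap print_rank_0 for print_rank_last: once the head lives only on the last stage, rank 0 is no longer the process that knows whether the head key was missing from the checkpoint. As a hedged sketch of what such a helper does, assuming torch.distributed is the backend and the last global rank hosts the last pipeline stage (the real helper ships with the megatron package):

    import torch

    def print_rank_last(message):
        # Sketch only: print from the globally last rank, which hosts the
        # last pipeline stage under the rank ordering assumed here.
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == torch.distributed.get_world_size() - 1:
                print(message, flush=True)
        else:
            print(message, flush=True)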