Commit f772fbc9 authored by Jared Casper

Only create task heads on last pipeline stage.

parent 6fa36844
@@ -17,7 +17,7 @@
 import torch
-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
@@ -45,6 +45,7 @@ class ClassificationBase(PipelinedMegatronModule):
                                        args.num_layers))
         # Multi-choice head.
+        if mpu.is_pipeline_last_stage():
             self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.classification_head = get_linear_layer(args.hidden_size,
                                                         self.num_classes,
@@ -85,6 +86,7 @@ class ClassificationBase(PipelinedMegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
             state_dict_[self._classification_head_key] \
                 = self.classification_head.state_dict(
                     destination, prefix, keep_vars)
@@ -95,11 +97,12 @@ class ClassificationBase(PipelinedMegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
+        if mpu.is_pipeline_last_stage():
             if self._classification_head_key in state_dict:
                 self.classification_head.load_state_dict(
                     state_dict[self._classification_head_key], strict=strict)
             else:
-                print_rank_0('***WARNING*** could not find {} in the checkpoint, '
+                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                 'initializing to random'.format(
                                     self._classification_head_key))
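The guard is the standard pipeline-model-parallel pattern: only the last stage computes the final hidden states, so only that rank should construct (and train) the task head, while earlier stages just forward activations. Below is a minimal, self-contained sketch of the idea. It is illustrative only: ToyClassifier, its trunk, and the local is_pipeline_last_stage helper are stand-ins, not Megatron's API (Megatron's real check is mpu.is_pipeline_last_stage()).

import torch


def is_pipeline_last_stage(rank: int = 0, world_size: int = 1) -> bool:
    # Stand-in for Megatron's mpu.is_pipeline_last_stage():
    # true only on the final pipeline stage.
    return rank == world_size - 1


class ToyClassifier(torch.nn.Module):
    """Illustrative two-stage model; not Megatron's ClassificationBase."""

    def __init__(self, hidden_size, num_classes, rank=0, world_size=1):
        super().__init__()
        self.last_stage = is_pipeline_last_stage(rank, world_size)
        # Every stage owns its slice of the trunk (stand-in for the LM).
        self.trunk = torch.nn.Linear(hidden_size, hidden_size)
        if self.last_stage:
            # The head is created only where the final hidden states live,
            # mirroring the guard this commit adds around the task heads.
            self.dropout = torch.nn.Dropout(0.1)
            self.head = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        h = self.trunk(x)
        if self.last_stage:
            return self.head(self.dropout(h))
        return h  # earlier stages pass activations to the next stage

The multiple-choice model below receives the identical guard.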
@@ -17,7 +17,7 @@
 import torch
-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
@@ -44,6 +44,7 @@ class MultipleChoiceBase(PipelinedMegatronModule):
                                        args.num_layers))
         # Multi-choice head.
+        if mpu.is_pipeline_last_stage():
             self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                      init_method)
@@ -97,6 +98,7 @@ class MultipleChoiceBase(PipelinedMegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
             state_dict_[self._multichoice_head_key] \
                 = self.multichoice_head.state_dict(
                     destination, prefix, keep_vars)
@@ -107,11 +109,12 @@ class MultipleChoiceBase(PipelinedMegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
+        if mpu.is_pipeline_last_stage():
             if self._multichoice_head_key in state_dict:
                 self.multichoice_head.load_state_dict(
                     state_dict[self._multichoice_head_key], strict=strict)
             else:
-                print_rank_0('***WARNING*** could not find {} in the checkpoint, '
+                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                 'initializing to random'.format(
                                     self._multichoice_head_key))
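The checkpoint changes mirror the constructor guard: a head that exists only on the last stage can be saved and loaded only there, and the load path tolerates older checkpoints that predate the head. A hedged sketch continuing the ToyClassifier example above (save_checkpoint, load_checkpoint, and HEAD_KEY are illustrative names, not Megatron's checkpoint API):

HEAD_KEY = 'classification_head'  # illustrative key name


def save_checkpoint(model):
    # Only the last stage has a head, so only it writes the head's weights.
    state = {'trunk': model.trunk.state_dict()}
    if model.last_stage:
        state[HEAD_KEY] = model.head.state_dict()
    return state


def load_checkpoint(model, state):
    model.trunk.load_state_dict(state['trunk'])
    if model.last_stage:
        if HEAD_KEY in state:
            model.head.load_state_dict(state[HEAD_KEY])
        else:
            # Mirrors the diff's warning path: tolerate checkpoints saved
            # before the head existed and keep the random initialization.
            print('***WARNING*** could not find {} in the checkpoint, '
                  'initializing to random'.format(HEAD_KEY))


# Usage: round-trip a two-stage model's last stage through a checkpoint.
model = ToyClassifier(hidden_size=8, num_classes=3, rank=1, world_size=2)
load_checkpoint(model, save_checkpoint(model))

Renaming print_rank_0 to print_rank_last is needed for the same reason: print_rank_0 emits only on global rank 0, which is the first pipeline stage and never executes the guarded branch, so the warning would be silently dropped; print_rank_last emits on the rank that actually runs this code.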