Unverified Commit ae71036c authored by ver217, committed by GitHub

[utils] refactor parallel layers checkpoint and bcast model on loading checkpoint (#1548)

* refactor parallel layer

* broadcast rank0 model after load ckpt
parent 2bed0968
@@ -5,9 +5,11 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from contextlib import contextmanager
 
 
 class ParallelLayer(nn.Module):
+    global_state_dict: bool = True
 
     def __init__(self):
         super().__init__()
@@ -26,10 +28,35 @@ class ParallelLayer(nn.Module):
         self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
             ParallelMode.PIPELINE)
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                              error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                                      error_msgs)
+    def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                     error_msgs):
+        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                             error_msgs)
+
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
+        return super()._save_to_state_dict(destination, prefix, keep_vars)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        if self.global_state_dict:
             if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
                 missing_keys.clear()
                 unexpected_keys.clear()
+            return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
+                                                     unexpected_keys, error_msgs)
+        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                             error_msgs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if self.global_state_dict:
+            return self._save_to_global_state_dict(destination, prefix, keep_vars)
+        return super()._save_to_state_dict(destination, prefix, keep_vars)
+
+    @classmethod
+    @contextmanager
+    def use_local_state_dict(cls):
+        try:
+            cls.global_state_dict = False
+            yield
+        finally:
+            cls.global_state_dict = True
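The refactor above makes the global (gathered) checkpoint format the default while keeping per-rank saving available: `state_dict()` and `load_state_dict()` now route through the `_global_state_dict` hooks unless the class-level `global_state_dict` flag is cleared, and `use_local_state_dict` flips that flag for the duration of a `with` block. A minimal usage sketch, assuming a model built from `ParallelLayer` modules and an already-initialized parallel context (the import path, `model`, `rank`, and the shard filename are illustrative, not part of this diff):

import torch
from colossalai.nn.layer.base_layer import ParallelLayer  # import path assumed

def save_local_shards(model: torch.nn.Module, rank: int):
    # Inside the context, every ParallelLayer falls back to the plain
    # nn.Module save path, so state_dict() holds only this rank's partition.
    with ParallelLayer.use_local_state_dict():
        torch.save(model.state_dict(), f'shard_rank{rank}.pt')
    # On exit, the finally block restores global_state_dict = True, so
    # subsequent saves gather full tensors again.

Because `global_state_dict` is a class attribute, the context manager switches every `ParallelLayer` subclass at once, which is what lets a whole model toggle checkpoint formats without threading a flag through each layer.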
@@ -189,7 +189,7 @@ class Classifier1D(ParallelLayer):
         num_partition = gpc.get_world_size(ParallelMode.TENSOR)
         set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -215,9 +215,9 @@ class Classifier1D(ParallelLayer):
                 weight_key: True,
                 bias_key: False
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()
@@ -326,7 +326,7 @@ class VocabParallelClassifier1D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -352,9 +352,9 @@ class VocabParallelClassifier1D(ParallelLayer):
                 weight_key: True,
                 bias_key: True
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()
@@ -461,7 +461,7 @@ class Linear1D_Col(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -486,9 +486,9 @@ class Linear1D_Col(ParallelLayer):
                 weight_key: True,
                 bias_key: True
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -598,7 +598,7 @@ class Linear1D_Row(ParallelLayer):
         num_partition = gpc.get_world_size(ParallelMode.TENSOR)
         set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -623,9 +623,9 @@ class Linear1D_Row(ParallelLayer):
                 weight_key: True,
                 bias_key: False
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -738,7 +738,7 @@ class Embedding1D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -751,9 +751,9 @@ class Embedding1D(ParallelLayer):
                                                          ParallelMode.PARALLEL_1D,
                                                          dims={weight_key: -1},
                                                          partition_states={weight_key: True})
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
         local_state = gather_tensor_parallel_state_dict(local_state,
@@ -773,7 +773,7 @@ class Embedding1D(ParallelLayer):
 
 @LAYERS.register_module
-class VocabParallelEmbedding1D(torch.nn.Module):
+class VocabParallelEmbedding1D(ParallelLayer):
     r"""Embedding parallelized in the vocabulary dimension.
 
     Args:
@@ -847,7 +847,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -860,9 +860,9 @@ class VocabParallelEmbedding1D(torch.nn.Module):
                                                          ParallelMode.PARALLEL_1D,
                                                          dims={weight_key: 0},
                                                          partition_states={weight_key: True})
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
         local_state = gather_tensor_parallel_state_dict(local_state,
...
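Every override in this file follows the same division of labor: `_load_from_global_state_dict` takes the full tensor held by rank 0 of the tensor-parallel group and scatters each rank's slice along the declared `dims` via `partition_tensor_parallel_state_dict`, while `_save_to_global_state_dict` reassembles the full tensor with `gather_tensor_parallel_state_dict` so the checkpoint stays independent of the parallel layout. A self-contained sketch of that round trip written against plain `torch.distributed` (the function names below are illustrative stand-ins for the colossalai helpers, not their actual implementation):

import torch
import torch.distributed as dist

def partition_param(full_weight: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Load direction: keep only this rank's slice of the rank-0 tensor.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    return full_weight.chunk(world_size, dim=dim)[rank].contiguous()

def gather_param(local_weight: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Save direction: reassemble the full tensor from all ranks' slices so
    # the checkpoint can be reloaded under a different parallel layout.
    world_size = dist.get_world_size()
    shards = [torch.empty_like(local_weight) for _ in range(world_size)]
    dist.all_gather(shards, local_weight)
    return torch.cat(shards, dim=dim)

The `dims` argument in the real helpers plays the role of `dim` above: -1 for column-partitioned weights such as `Embedding1D`, 0 for vocabulary-partitioned weights such as `VocabParallelEmbedding1D`.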
@@ -94,7 +94,7 @@ class Linear2D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -137,9 +137,9 @@ class Linear2D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -252,7 +252,7 @@ class LayerNorm2D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -294,9 +294,9 @@ class LayerNorm2D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -443,7 +443,7 @@ class PatchEmbedding2D(ParallelLayer):
             bias_initializer(self.bias, fan_in=fan_in)
             position_embed_initializer(self.pos_embed)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -503,9 +503,9 @@ class PatchEmbedding2D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         cls_token_key = prefix + 'cls_token'
@@ -651,7 +651,7 @@ class Embedding2D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -676,9 +676,9 @@ class Embedding2D(ParallelLayer):
             partition_states={weight_key: True},
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
@@ -712,7 +712,7 @@ class Embedding2D(ParallelLayer):
 
 @LAYERS.register_module
-class VocabParallelEmbedding2D(torch.nn.Module):
+class VocabParallelEmbedding2D(ParallelLayer):
     r"""Embedding parallelized in the vocabulary dimension.
 
     Args:
@@ -789,7 +789,7 @@ class VocabParallelEmbedding2D(torch.nn.Module):
            with torch.no_grad():
                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -814,9 +814,9 @@ class VocabParallelEmbedding2D(torch.nn.Module):
             partition_states={weight_key: True},
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
@@ -924,7 +924,7 @@ class Classifier2D(ParallelLayer):
             broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
             broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -968,9 +968,9 @@ class Classifier2D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()
@@ -1095,7 +1095,7 @@ class VocabParallelClassifier2D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -1139,9 +1139,9 @@ class VocabParallelClassifier2D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()
...
@@ -96,7 +96,7 @@ class Linear2p5D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -143,9 +143,9 @@ class Linear2p5D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0:
             weight_key = prefix + 'weight'
             bias_key = prefix + 'bias'
@@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -314,9 +314,9 @@ class LayerNorm2p5D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -463,7 +463,7 @@ class PatchEmbedding2p5D(ParallelLayer):
             bias_initializer(self.bias, fan_in=fan_in)
             position_embed_initializer(self.pos_embed)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -523,9 +523,9 @@ class PatchEmbedding2p5D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         cls_token_key = prefix + 'cls_token'
@@ -671,7 +671,7 @@ class Embedding2p5D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -696,9 +696,9 @@ class Embedding2p5D(ParallelLayer):
             partition_states={weight_key: True},
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
@@ -733,7 +733,7 @@ class Embedding2p5D(ParallelLayer):
 
 @LAYERS.register_module
-class VocabParallelEmbedding2p5D(torch.nn.Module):
+class VocabParallelEmbedding2p5D(ParallelLayer):
     """Embedding parallelized in the vocabulary dimension.
 
     Args:
@@ -810,7 +810,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
            with torch.no_grad():
                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -835,9 +835,9 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
             partition_states={weight_key: True},
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
@@ -950,7 +950,7 @@ class Classifier2p5D(ParallelLayer):
             broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
             broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -994,9 +994,9 @@ class Classifier2p5D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()
@@ -1123,7 +1123,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -1167,7 +1167,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def forward(self, x: Tensor) -> Tensor:
         # input: [m/dq, n/q, k/q]
...
@@ -70,7 +70,7 @@ class LayerNorm3D(ParallelLayer):
         if self.bias is not None:
             init.zeros_()(self.bias)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -105,9 +105,9 @@ class LayerNorm3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -207,7 +207,7 @@ class Linear3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -265,9 +265,9 @@ class Linear3D(ParallelLayer):
             },
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -400,7 +400,7 @@ class Classifier3D(ParallelLayer):
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
             broadcast(self.bias, input_src_rank, self.input_parallel_mode)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -437,9 +437,9 @@ class Classifier3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()
@@ -551,7 +551,7 @@ class VocabParallelClassifier3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -610,9 +610,9 @@ class VocabParallelClassifier3D(ParallelLayer):
             },
        )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})
@@ -763,7 +763,7 @@ class PatchEmbedding3D(ParallelLayer):
             self.cls_token.register_hook(self._sync_grad_hook)
             self.pos_embed.register_hook(self._sync_grad_hook)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -812,9 +812,9 @@ class PatchEmbedding3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         cls_token_key = prefix + 'cls_token'
@@ -937,7 +937,7 @@ class Embedding3D(ParallelLayer):
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -961,9 +961,9 @@ class Embedding3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
@@ -991,7 +991,7 @@ class Embedding3D(ParallelLayer):
 
 @LAYERS.register_module
-class VocabParallelEmbedding3D(torch.nn.Module):
+class VocabParallelEmbedding3D(ParallelLayer):
     r"""Embedding parallelized in the vocabulary dimension.
 
    Args:
@@ -1070,7 +1070,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
            with torch.no_grad():
                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -1104,9 +1104,9 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             partition_states={weight_key: True},
         )
-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
...
@@ -3,9 +3,9 @@ from itertools import chain
 
 import torch
 import torch.distributed as dist
-from colossalai.communication.collective import scatter_object_list
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.constants import IS_TENSOR_PARALLEL
 
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
@@ -190,6 +190,15 @@ def save_checkpoint(file,
     torch.save(checkpoint, file, **kwargs)
 
 
+def broadcast_model(model: torch.nn.Module):
+    src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
+    for p in model.parameters():
+        if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
+            group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
+                ParallelMode.TENSOR)
+            dist.broadcast(p, src_rank, group=group)
+
+
 def load_checkpoint(
         file,
         model: torch.nn.Module,
@@ -225,6 +234,7 @@ def load_checkpoint(
         model_state = partition_pipeline_parallel_state_dict(model, model_state)
     try:
         model.load_state_dict(model_state, strict=strict)
+        broadcast_model(model)
    except RuntimeError as e:
        error_msgs = str(e)
        if error_msgs.startswith("Error(s) in loading state_dict for "):
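With `broadcast_model` in place, rank 0 of each tensor-parallel group supplies the authoritative values for non-partitioned parameters: tensors tagged `IS_TENSOR_PARALLEL` are already scattered by the layer hooks during `load_state_dict`, and everything else is broadcast from the group's first rank afterwards, so parameters a non-zero rank never loaded (its missing keys are cleared in `ParallelLayer._load_from_state_dict`) still end up identical to rank 0's. A hedged end-to-end sketch, assuming `colossalai.launch` has already initialized the parallel context (the checkpoint path and `model` are placeholders, and the signatures are abbreviated):

# Illustrative save/load round trip; 'ckpt.pt' and `model` are placeholders.
from colossalai.utils import save_checkpoint, load_checkpoint  # re-export path assumed

# Gather a layout-independent global state dict and write it once.
save_checkpoint('ckpt.pt', epoch=0, model=model)

# Load it back: the full tensors are partitioned by the ParallelLayer hooks,
# then broadcast_model() syncs the remaining non-tensor-parallel parameters.
last_epoch = load_checkpoint('ckpt.pt', model)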
......