"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "2a951955ade14fd067bc5bee34a5ff7e57513ac6"
Unverified Commit 2ac24040 authored by digger yu's avatar digger yu Committed by GitHub
Browse files

fix some typo colossalai/shardformer (#4160)

parent c77b3b19
...@@ -252,7 +252,7 @@ class ModelSharder: ...@@ -252,7 +252,7 @@ class ModelSharder:
def shard(self) -> None: def shard(self) -> None:
""" """
Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing. Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
""" """
... ...
......
...@@ -48,13 +48,13 @@ class DistCrossEntropy(Function): ...@@ -48,13 +48,13 @@ class DistCrossEntropy(Function):
# [down, up) => false, other device and -100 => true # [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size delta = (global_vocab_size + world_size - 1) // world_size
down_shreshold = rank * delta down_threshold = rank * delta
up_shreshold = down_shreshold + delta up_threshold = down_threshold + delta
mask = (target < down_shreshold) | (target >= up_shreshold) mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_shreshold masked_target = target.clone() - down_threshold
masked_target[mask] = 0 masked_target[mask] = 0
# reshape the logist and target # reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size] # reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len] # reshape the labels to [batch_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size) logits_2d = vocab_logits.view(-1, partition_vocab_size)
...@@ -79,7 +79,7 @@ class DistCrossEntropy(Function): ...@@ -79,7 +79,7 @@ class DistCrossEntropy(Function):
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
loss = torch.sum(loss).div_(torch.sum(loss != 0.0)) loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
# caculate the softmax # calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, mask, masked_target_1d) ctx.save_for_backward(exp_logits, mask, masked_target_1d)
......
...@@ -66,7 +66,7 @@ class Policy(ABC): ...@@ -66,7 +66,7 @@ class Policy(ABC):
like BertPolicy for Bert Model or OPTPolicy for OPT model. like BertPolicy for Bert Model or OPTPolicy for OPT model.
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`. built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify. If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
""" """
......
...@@ -73,7 +73,7 @@ class ModelSharder(object): ...@@ -73,7 +73,7 @@ class ModelSharder(object):
layer (torch.nn.Module): The object of layer to shard layer (torch.nn.Module): The object of layer to shard
origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name. origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
attr_replacement (Dict): The attribute dict to modify attr_replacement (Dict): The attribute dict to modify
param_replacement (List[Callable]): The function list to get parameter shard information in polic param_replacement (List[Callable]): The function list to get parameter shard information in policy
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
""" """
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment