"vscode:/vscode.git/clone" did not exist on "d06980dfa7e8dcc1738656beb46d3735c86faa21"
Unverified commit e4f506a0, authored by hugo-syn and committed by GitHub

chore: Fix multiple typos (#617)


Signed-off-by: hugo-syn <hugo.vincent@synacktiv.com>
parent 051db0d7
@@ -37,7 +37,7 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd
 # [GPU-0] Peak memory use = 3000MiB
 # FSDP with deferred initialization:
-# Modules initialized with empty paramaters via `device='meta'` option. Zero load on device
+# Modules initialized with empty parameters via `device='meta'` option. Zero load on device
 # memory until torch.distributed.fsdp.FullyShardedDataParallel mode triggers a reset on
 # on already sharded model parameters.
 $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init
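The comment fixed above describes deferred initialization: modules are built on the 'meta' device so no GPU memory is allocated until FSDP materializes the already-sharded parameters. A minimal sketch of that pattern, assuming a process group has already been set up by torchrun (the model, sizes, and helper name below are illustrative, not the repo's fsdp.py):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_fsdp_model_deferred():
    # Parameters created under the 'meta' device hold no storage, so nothing
    # lands in GPU memory while the module tree is being constructed.
    with torch.device('meta'):
        model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
    # FSDP materializes each shard on the real device; param_init_fn tells it
    # how to turn meta tensors into (empty) real ones before the reset.
    return FSDP(
        model,
        device_id=torch.cuda.current_device(),
        param_init_fn=lambda m: m.to_empty(device=torch.cuda.current_device(), recurse=False),
    )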
@@ -250,7 +250,7 @@ class FusedAttnRunner:
         self._setup_inputs()

         def grad_func(func, *args, **kwargs):
-            # Gradient is small, use a gradient multiplier to amplify the graident
+            # Gradient is small, use a gradient multiplier to amplify the gradient
             gradient_multiplier = self.valid_len_q * self.num_heads_q
             if is_causal_mask(self.attn_mask_type):
                 gradient_multiplier /= 10
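The comment fixed in this hunk explains why the multiplier exists: raw attention gradients are tiny, so the runner scales the scalar being differentiated, which scales every back-propagated gradient by the same factor and keeps the comparison away from the noise floor. A hedged JAX sketch of the idea (grad_func_sketch and the toy input are illustrative, not the runner's actual code):

import jax
import jax.numpy as jnp

def grad_func_sketch(func, *args, gradient_multiplier=1.0, **kwargs):
    # Reduce to a scalar and amplify it; differentiation scales all input
    # gradients by the same constant, so small values stay numerically comparable.
    return jnp.mean(func(*args, **kwargs), dtype=jnp.float32) * gradient_multiplier

# Differentiate a toy function with an amplified scalar output.
x = jnp.ones((4, 8))
value, grad = jax.value_and_grad(
    lambda t: grad_func_sketch(jnp.tanh, t, gradient_multiplier=1e4)
)(x)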
@@ -204,7 +204,7 @@ class TestFP8Functions(unittest.TestCase):
         (MeshResource(None, 'tp')),
         (MeshResource('dp', 'tp')),
     )
-    # TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme
+    # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
     mesh_shape = (1, 1)
     devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
     with jax.sharding.Mesh(devices, ('dp', 'tp')):
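For context, the mesh in this hunk pins the test to one device: a 1x1 mesh still defines the 'dp' and 'tp' axes, so the sharded code paths execute without multiple GPUs. A small standalone sketch of that setup (the array shape and its partitioning are made up for illustration):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
mesh = Mesh(devices, ('dp', 'tp'))
with mesh:
    # Place a (batch, hidden) array with rows sharded over 'dp' and columns
    # over 'tp'; on a 1x1 mesh both axes are size 1, so nothing is actually split.
    x = jax.device_put(
        np.zeros((8, 16), dtype=np.float32),
        NamedSharding(mesh, PartitionSpec('dp', 'tp')),
    )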
@@ -100,7 +100,7 @@ class TestAttentionTp(unittest.TestCase):
         paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
         if interleave:
             # Due to the interleaved qkv layout, need to concat on num_head
-            # dimention for column parallel linear in MultiHeadAttention layer
+            # dimension for column parallel linear in MultiHeadAttention layer
             assert axis == 0
             assert [3 * self.hidden_size // self.world_size,
                     self.hidden_size] == partial_weight.shape
@@ -101,7 +101,7 @@ class TestTransformerTp(unittest.TestCase):
         paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
         if interleave:
             # Due to the interleaved qkv layout, need to concat on num_head
-            # dimention for column parallel linear in MultiHeadAttention layer
+            # dimension for column parallel linear in MultiHeadAttention layer
             assert axis == 0
             assert [3 * self.hidden_size // self.world_size,
                     self.hidden_size] == partial_weight.shape
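Both Paddle hunks above check the same tensor-parallel invariant: every rank holds a [3 * hidden_size // world_size, hidden_size] shard of the fused QKV weight, and with the interleaved layout the gathered shards are stitched back together along the head (row) dimension. A hedged sketch of that gather (gather_qkv_weight is a hypothetical helper, not the tests' code):

import paddle
import paddle.distributed as dist

def gather_qkv_weight(partial_weight, tp_group):
    # Each rank contributes its shard of shape [3 * hidden // world_size, hidden].
    total_weight = []
    dist.all_gather(total_weight, partial_weight, group=tp_group)
    # With the interleaved qkv layout each shard already carries complete
    # (q, k, v) blocks for its subset of heads, so concatenating along axis 0
    # (the head/row dimension) rebuilds the full [3 * hidden, hidden] weight.
    return paddle.concat(total_weight, axis=0)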
@@ -668,7 +668,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
 @pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
 def test_te_layer_misc(dtype, model_configs, model, qkv_format):
-    """Test TransformerLayer module with miscellanous settings"""
+    """Test TransformerLayer module with miscellaneous settings"""
     ckpt_attn = True
     fused_qkv_params = True
     RoPE = True
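The docstring fixed above belongs to a heavily parametrized test: pytest runs the body once for every combination produced by the stacked parametrize decorators. A toy, self-contained sketch of how those axes multiply (the parameter values mirror the hunk; the body is illustrative only):

import pytest

@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc_sketch(model, qkv_format):
    # 2 models x 2 qkv formats = 4 generated cases, each reported as a
    # separate test id in pytest's output.
    assert qkv_format in ("bshd", "sbhd")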