Commit ce02cd51 authored by dongcl

Megatron v0.11.0

parent aeed6d97
......@@ -66,6 +66,10 @@ def unpermute(
):
```
### Support for the [flux kernel](http://10.6.10.68/dcutoolkit/deeplearing/flux)
Under tensor parallelism (TP), users can enable the flux fused communication-computation kernels for better training and inference performance. The project integrates flux by replacing Transformer Engine methods; to use it, set the environment variable USE_FLUX_OVERLAP=1 and set transformer-impl to transformer_engine.
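A minimal sketch (not part of the commit) of how the switch is read on the Python side; the adaptation keys off the same USE_FLUX_OVERLAP variable before registering the flux linears:
```
import os

# The adaptation only swaps in the flux column/row-parallel linears and the
# flux GPT layer spec when USE_FLUX_OVERLAP is set to a non-zero value.
use_flux_overlap = bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))
print(f"flux overlap enabled: {use_flux_overlap}")
```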
### Usage
Go to the examples directory, which contains execution scripts for the supported models; download the required datasets yourself from: https://r0ddbu55vzx.feishu.cn/drive/folder/ZxHHfCoX4lg75td2hTqcmiAin3g
```
......
......@@ -99,7 +99,7 @@ class CoreAdaptation(MegatronAdaptationABC):
)
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init,
gpt_model_init_wrapper,
shared_embedding_or_mtp_embedding_weight
)
from ..training.utils import get_batch_on_this_tp_rank
......@@ -116,20 +116,20 @@ class CoreAdaptation(MegatronAdaptationABC):
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
from megatron.core.models.gpt.gpt_model import GPTModel
setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper, transformer_block_forward
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
# Transformer block
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.forward',
transformer_block_forward)
# Transformer config
MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig',
......@@ -141,9 +141,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
# apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
......@@ -166,7 +166,6 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper
# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
......@@ -188,17 +187,19 @@ class CoreAdaptation(MegatronAdaptationABC):
apply_wrapper=True)
# flux
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
parallel_linear_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
ColumnParallelLinearPatch.forward)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
parallel_linear_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
RowParallelLinearPatch.forward)
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec
MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
FluxColumnParallelLinear)
MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TERowParallelLinear",
FluxRowParallelLinear)
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
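# Note: with USE_FLUX_OVERLAP=1 the Transformer Engine column/row-parallel linear
# classes and the TE GPT layer spec are swapped for their flux counterparts at patch
# time; without it the stock TE implementations remain in place.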
def patch_training(self):
from ..training.tokenizer import build_tokenizer
......@@ -232,19 +233,22 @@ class LegacyAdaptation(MegatronAdaptationABC):
self.patch_legacy_models()
def patch_legacy_models(self):
from ..legacy.model.transformer import ParallelMLP, ParallelAttention
from ..legacy.model.transformer import ParallelMLPPatch, ParallelAttentionPatch
from ..legacy.model.utils import get_norm
# ParallelMLP
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
ParallelMLP.__init__)
ParallelMLPPatch.__init__)
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
ParallelAttention.forward)
ParallelAttentionPatch.forward)
# rms_norm.RMSNorm
MegatronAdaptation.register('megatron.legacy.model.rms_norm.RMSNorm.forward',
torch.compile(mode="max-autotune-no-cudagraphs"),
apply_wrapper=True)
MegatronAdaptation.register('megatron.legacy.model.utils.get_norm',
get_norm)
MegatronAdaptation.execute()
from .transformer.transformer_block import transformer_block_init_wrapper, transformer_block_forward
from .transformer.transformer_block import transformer_block_init_wrapper
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from typing import List
import torch
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from typing import Literal
import torch
......
import warnings
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version
try:
from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TENorm,
)
except ImportError:
warnings.warn('transformer_engine is not installed.')
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
except ImportError:
warnings.warn('Apex is not installed.')
def get_gpt_layer_with_flux_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use flux modules (required for fp8 training).
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with flux modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = get_mlp_module_flux_spec(
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=FluxColumnParallelLinear,
linear_q_down_proj=FluxColumnParallelLinear,
linear_q_up_proj=FluxColumnParallelLinear,
linear_kv_down_proj=FluxColumnParallelLinear,
linear_kv_up_proj=FluxColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=FluxRowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
kv_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
# TENorm significantly harms convergence when used
# for QKLayerNorm if TE Version < 1.9;
# we instead use the Apex implementation.
qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=FluxColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=FluxRowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
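# Usage sketch (illustrative, not part of this commit): when USE_FLUX_OVERLAP=1 the
# adaptation registers this function in place of
# get_gpt_layer_with_transformer_engine_spec, so callers build layer specs as usual,
# e.g. layer_spec = get_gpt_layer_with_flux_spec(num_experts=None, qk_layernorm=True)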
def get_mlp_module_flux_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=FluxColumnParallelLinear,
linear_fc2=FluxRowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
return get_moe_module_spec(
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
import os
import logging
from typing import Literal, Optional
from functools import wraps
......@@ -16,160 +17,76 @@ from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from dcu_megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
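# When USE_FLUX_OVERLAP is enabled, the wrapper below rebuilds the post-process output
# projection created by GPTModel.__init__ as a FluxColumnParallelLinear, so the final
# logits GEMM also runs through the flux fused kernels.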
def gpt_model_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
if (
self.post_process
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
):
self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size,
self.vocab_size,
config=self.config,
init_method=self.config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
def gpt_model_init(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
mtp_spec: ModuleSpec = None
) -> None:
super(GPTModel, self).__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
rope_scaling_factor=rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Cache for RoPE tensors which do not change between iterations.
self.rotary_pos_emb_cache = {}
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
# add mtp
self.mtp_spec: ModuleSpec = mtp_spec
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
if self.post_process and self.training and self.num_nextn_predict_layers:
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
layer_number=i,
pre_process=self.pre_process,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
self.setup_embeddings_and_output_layer()
# add mtp
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
if self.num_nextn_predict_layers:
assert hasattr(self.config, "mtp_spec")
self.mtp_spec: ModuleSpec = self.config.mtp_spec
self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
if self.post_process and self.training:
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
self.config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
layer_number=i,
pre_process=self.pre_process,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
)
for i in range(self.num_nextn_predict_layers)
]
)
for i in range(self.num_nextn_predict_layers)
]
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
if self.pre_process or self.post_process:
setup_mtp_embeddings(self)
if self.num_nextn_predict_layers and (self.pre_process or self.post_process):
setup_mtp_embeddings(self)
return wrapper
def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
......@@ -424,10 +341,10 @@ def gpt_model_forward(
if (
self.num_nextn_predict_layers
and getattr(self.decoder, "final_layernorm", None) is not None
and getattr(self.decoder, "main_final_layernorm", None) is not None
):
# move block main model final norms here
hidden_states = self.decoder.final_layernorm(hidden_states)
hidden_states = self.decoder.main_final_layernorm(hidden_states)
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
......
from .layers import (
parallel_linear_init_wrapper,
ColumnParallelLinearPatch,
RowParallelLinearPatch,
FluxColumnParallelLinear,
FluxRowParallelLinear,
vocab_parallel_embedding_forward,
vocab_parallel_embedding_init,
)
\ No newline at end of file
import os
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional
import flux
try:
import flux
except ImportError:
raise ImportError("flux is NOT installed")
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
......@@ -26,15 +34,24 @@ from megatron.core.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
_reduce_scatter_along_first_dim,
_gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
from megatron.core.tensor_parallel.layers import (
custom_fwd,
custom_bwd,
dist_all_gather_func,
linear_with_frozen_weight,
linear_with_grad_accumulation_and_async_allreduce
)
from dcu_megatron.core.utils import is_flux_min_version
_grad_accum_fusion_available = True
try:
......@@ -146,6 +163,19 @@ def vocab_parallel_embedding_forward(self, input_, weight=None):
return output
def get_tensor_model_parallel_node_size(group=None):
""" 获取节点数
"""
if group is None:
group=get_tensor_model_parallel_group()
hostname = socket.gethostname()
hostnames = [None] * get_tensor_model_parallel_world_size()
torch.distributed.all_gather_object(hostnames, hostname, group=group)
num_nodes = len(set(hostnames))
return num_nodes
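# AGLinear: in the sequence-parallel path the forward fuses the input all-gather with
# the GEMM via flux.AGKernel, and the backward computes the input gradient with a fused
# GEMM + reduce-scatter (flux.GemmRS); without sequence parallelism it falls back to a
# plain matmul.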
class AGLinear(torch.autograd.Function):
@staticmethod
@custom_fwd
......@@ -160,6 +190,8 @@ class AGLinear(torch.autograd.Function):
grad_output_buffer,
wgrad_deferral_limit,
transpose_weight=False,
fw_ag_gemm_op=None,
bw_gemm_rs_op=None,
):
"""Forward."""
ctx.save_for_backward(input, weight)
......@@ -170,63 +202,44 @@ class AGLinear(torch.autograd.Function):
ctx.wgrad_deferral_limit = wgrad_deferral_limit
ctx.grad_output_buffer = grad_output_buffer
ctx.transpose_weight = transpose_weight
ctx.bw_gemm_rs_op = bw_gemm_rs_op
sequence_len = input.size(0)
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
input = input.view(
input.shape[0] * input.shape[1], input.shape[2]
)
M, K = list(input.size())
N = weight.size(0)
M = M * get_tensor_model_parallel_world_size()
if sequence_parallel:
sequence_len, batch_size, input_hidden_size = input.size()
output_hidden_size = weight.size(0)
world_size = get_tensor_model_parallel_world_size()
if transpose_weight:
weight = weight.t().contiguous()
if fw_ag_gemm_op is None:
if not is_flux_min_version("1.1.0"):
fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_hidden_size,
input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
if sequence_parallel:
ag_gemm_kernel = flux.AGKernel(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_world_size() // torch.cuda.device_count(),
M,
N,
K,
input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
output = ag_gemm_kernel.forward(
input,
weight,
bias=bias,
input_scale=input_scale,
weight_scale=weight_scale,
output_scale=None,
fast_accum=False
)
else:
output_buf = torch.empty([M, N], dtype=input.dtype, device=input.device)
gemm_only_op = flux.GemmOnly(
input_dtype=input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
use_fp8_gemm=False,
)
output = gemm_only_op.forward(
input,
weight,
output = fw_ag_gemm_op.forward(
input.view(sequence_len * batch_size, -1),
weight.t().contiguous() if transpose_weight else weight,
bias=bias,
output_buf=output_buf,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False,
fast_accum=False
)
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len, input.size(0) // sequence_len, -1)
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len * world_size, batch_size, -1)
else:
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
......@@ -239,8 +252,9 @@ class AGLinear(torch.autograd.Function):
grad_output_buffer = ctx.grad_output_buffer
wgrad_deferral_limit = ctx.wgrad_deferral_limit
transpose_weight = ctx.transpose_weight
bw_gemm_rs_op = ctx.bw_gemm_rs_op
wgrad_compute = True
wgrad_compute = weight.requires_grad
if grad_output_buffer is not None:
if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
grad_output_buffer.append(grad_output)
......@@ -266,29 +280,25 @@ class AGLinear(torch.autograd.Function):
total_input = input
if ctx.sequence_parallel:
sequence_len, batch_size, output_hidden_size = grad_output.size()
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
grad_output = grad_output.view(
sequence_len * batch_size, output_hidden_size
)
sequence_len, batch_size, _ = grad_output.size()
if not transpose_weight:
weight = weight.t().contiguous()
if bw_gemm_rs_op is None:
input_hidden_size = weight.size(-1)
if not is_flux_min_version("1.1.0"):
bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
input_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False
)
gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
world_size // torch.cuda.device_count(),
sequence_len * batch_size,
output_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False
)
grad_input = gemm_rs_op.forward(
grad_output,
weight,
grad_input = bw_gemm_rs_op.forward(
grad_output.view(sequence_len * batch_size, -1),
weight if transpose_weight else weight.t().contiguous(),
bias=None,
input_scale=None,
weight_scale=None,
......@@ -297,7 +307,7 @@ class AGLinear(torch.autograd.Function):
)
torch.cuda.current_stream().synchronize()
grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_group(), batch_size, -1)
grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
else:
grad_input = grad_output.matmul(weight)
......@@ -310,12 +320,14 @@ class AGLinear(torch.autograd.Function):
)
if not ctx.sequence_parallel and ctx.allreduce_dgrad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
if weight.requires_grad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
else:
grad_input = _reduce(grad_input)
return grad_input, None, None, None, None, None, None, None, None, None, None
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
......@@ -356,10 +368,10 @@ class AGLinear(torch.autograd.Function):
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.allreduce_dgrad:
if not ctx.sequence_parallel and ctx.allreduce_dgrad:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None
def ag_linear(
......@@ -372,6 +384,8 @@ def ag_linear(
grad_output_buffer: Optional[List[torch.Tensor]] = None,
wgrad_deferral_limit: Optional[int] = 0,
transpose_weight: Optional[bool] = False,
fw_ag_gemm_op=None,
bw_gemm_rs_op=None
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
......@@ -433,6 +447,11 @@ def ag_linear(
deferred. Disable by setting this to 0. Defaults to 0.
transpose_weight: transpose weight.
fw_ag_gemm_op: flux AGKernel for forward.
bw_gemm_rs_op: flux GemmRS for backward.
"""
args = [
......@@ -445,6 +464,8 @@ def ag_linear(
grad_output_buffer,
wgrad_deferral_limit,
transpose_weight,
fw_ag_gemm_op,
bw_gemm_rs_op,
]
if not ag_linear.warned:
......@@ -485,6 +506,8 @@ class LinearRS(torch.autograd.Function):
grad_output_buffer,
wgrad_deferral_limit,
transpose_weight=False,
fw_gemm_rs_op=None,
bw_ag_gemm_op=None
):
"""Forward."""
ctx.save_for_backward(input, weight)
......@@ -495,66 +518,40 @@ class LinearRS(torch.autograd.Function):
ctx.wgrad_deferral_limit = wgrad_deferral_limit
ctx.grad_output_buffer = grad_output_buffer
ctx.transpose_weight = transpose_weight
ctx.bw_ag_gemm_op = bw_ag_gemm_op
world_size = get_tensor_model_parallel_world_size()
input_dim = input.dim()
sequence_len = input.size(0)
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
input = input.view(
input.shape[0] * input.shape[1], input.shape[2]
)
M = input.size(0)
N = weight.size(0)
sequence_len, batch_size, _ = input.size()
output_hidden_size = weight.size(0)
if sequence_parallel:
if transpose_weight:
weight = weight.t().contiguous()
gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
world_size // torch.cuda.device_count(),
M,
N,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False,
)
output = gemm_rs_op.forward(
input,
weight,
if fw_gemm_rs_op is None:
if not is_flux_min_version("1.1.0"):
fw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
output_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False,
)
output = fw_gemm_rs_op.forward(
input.view(sequence_len * batch_size, -1),
weight.t().contiguous() if transpose_weight else weight,
bias=bias,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False,
)
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len // world_size, batch_size, -1)
else:
output = torch.empty([M, N], dtype=input.dtype, device=input.device)
gemm_only_op = flux.GemmOnly(
input_dtype=input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
use_fp8_gemm=False,
)
output = gemm_only_op.forward(
input,
weight,
bias=bias,
output_buf=output,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False,
)
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len, input.size(0) // sequence_len, -1)
if not sequence_parallel:
_reduce(output)
output = torch.matmul(input, weight.t())
return output
......@@ -567,69 +564,86 @@ class LinearRS(torch.autograd.Function):
grad_output_buffer = ctx.grad_output_buffer
wgrad_deferral_limit = ctx.wgrad_deferral_limit
transpose_weight = ctx.transpose_weight
bw_ag_gemm_op = ctx.bw_ag_gemm_op
wgrad_compute = True
wgrad_compute = weight.requires_grad
if grad_output_buffer is not None:
if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
grad_output_buffer.append(grad_output)
wgrad_compute = False
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
sequence_len, batch_size, _ = grad_output.size()
grad_output = grad_output.view(sequence_len * batch_size, -1)
M, K = list(grad_output.size())
M = M * world_size
N = weight.size(-1)
if not transpose_weight:
weight = weight.t().contiguous()
grad_input = torch.empty([M, N], dtype=input.dtype, device=input.device)
ag_kernel = flux.AGKernel(
get_tensor_model_parallel_group(),
world_size // torch.cuda.device_count(),
M,
N,
K,
input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
if wgrad_compute:
if ctx.sequence_parallel:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(
dim_size, grad_output.dtype, "mpu"
)
handle = dist_all_gather_func(
all_gather_buffer, grad_output, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_grad_output = all_gather_buffer
else:
total_grad_output = grad_output
output = ag_kernel.forward(
grad_output,
weight,
if ctx.sequence_parallel:
sequence_len, batch_size, output_hidden_size = grad_output.size()
input_hidden_size = weight.size(-1)
if bw_ag_gemm_op is None:
if not is_flux_min_version("1.1.0"):
bw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
input_hidden_size,
output_hidden_size,
grad_output.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
grad_input = bw_ag_gemm_op.forward(
grad_output.view(sequence_len * batch_size, -1),
weight if transpose_weight else weight.t().contiguous(),
bias=None,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False,
)
torch.cuda.current_stream().synchronize()
grad_input = grad_input.view(sequence_len * world_size, batch_size, -1)
else:
grad_input = grad_output.matmul(weight)
if not weight.requires_grad:
return grad_input, None, None, None, None, None, None, None, None, None, None
if ctx.sequence_parallel and wgrad_compute:
handle.wait()
if wgrad_compute:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, input
total_grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
total_grad_output, input
)
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
total_input, total_grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
total_input, total_grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
......@@ -657,10 +671,10 @@ class LinearRS(torch.autograd.Function):
else:
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
grad_weight = total_grad_output.t().matmul(total_input)
grad_bias = total_grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None
def linear_rs(
......@@ -673,6 +687,8 @@ def linear_rs(
grad_output_buffer: Optional[List[torch.Tensor]] = None,
wgrad_deferral_limit: Optional[int] = 0,
transpose_weight: Optional[bool] = False,
fw_gemm_rs_op=None,
bw_ag_gemm_op=None,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
......@@ -734,6 +750,11 @@ def linear_rs(
deferred. Disable by setting this to 0. Defaults to 0.
transpose_weight: transpose weight.
fw_gemm_rs_op: flux GemmRS for forward.
bw_ag_gemm_op: flux AGKernel for backward.
"""
args = [
......@@ -746,6 +767,8 @@ def linear_rs(
grad_output_buffer,
wgrad_deferral_limit,
transpose_weight,
fw_gemm_rs_op,
bw_ag_gemm_op,
]
if not linear_rs.warned:
......@@ -772,35 +795,99 @@ def linear_rs(
linear_rs.warned = False
def parallel_linear_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
# flux params
self.use_flux = False
if "use_flux" in kwargs:
self.use_flux = kwargs["use_flux"]
elif hasattr(self.config, "use_flux"):
self.use_flux = self.config.use_flux
self.flux_transpose_weight = False
if "flux_transpose_weight" in kwargs:
self.flux_transpose_weight = kwargs["flux_transpose_weight"]
elif hasattr(self.config, "flux_transpose_weight"):
self.flux_transpose_weight = self.config.flux_transpose_weight
return wrapper
class ColumnParallelLinearPatch(torch.nn.Module):
class FluxColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias
gather_output:
If true, call all-gather on output and make Y available to all GPUs,
otherwise, every GPU will have its output which is Y_i = XA_i
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It
returns the master weights used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimizations where bias can be fused with other
elementwise operations.
skip_weight_param_allocation:
If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Note that this does not
affect bias, which will be allocated if bias is True. Defaults to False.
embedding_activation_buffer:
This buffer holds the input activations of the final embedding
linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
grad_output_buffer:
This buffer holds the gradient outputs of the final embedding linear
layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
is_expert:
If True, the layer is treated as an MoE expert layer.
config:
ModelParallelConfig object
tp_comm_buffer_name:
Communication buffer name is not used in non-Transformer-Engine modules.
disable_grad_reduce:
If True, reduction of output gradients across tensor-parallel ranks
will be disabled. Defaults to False. This feature is used by Lora Adapter in Nemo to
delay and fuse reduction along with other gradients for performance optimization.
"""
def __init__(
self,
input_size,
output_size,
*,
config: ModelParallelConfig,
init_method: Callable,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
):
super(FluxColumnParallelLinear, self).__init__(
input_size=input_size,
output_size=output_size,
config=config,
init_method=init_method,
bias=bias,
gather_output=gather_output,
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=skip_weight_param_allocation,
embedding_activation_buffer=embedding_activation_buffer,
grad_output_buffer=grad_output_buffer,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
disable_grad_reduce=disable_grad_reduce,
)
# flux params
self._forward_impl = ag_linear
self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
self.previous_flux_params = (None,) * 5
self.fw_ag_gemm_op = None
self.bw_gemm_rs_op = None
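# The flux operators are specialized to a fixed (sequence length, batch size, hidden
# sizes, dtype) combination, so forward() caches one AGKernel/GemmRS pair in
# fw_ag_gemm_op/bw_gemm_rs_op and rebuilds them only when previous_flux_params changes.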
def forward(
self,
input_: torch.Tensor,
......@@ -863,30 +950,65 @@ class ColumnParallelLinearPatch(torch.nn.Module):
):
self.embedding_activation_buffer.append(input_parallel)
# Matrix multiply.
if self.use_flux:
self._forward_impl = ag_linear
elif not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
# flux kernels.
if self.sequence_parallel:
sequence_len, batch_size, input_hidden_size = input_parallel.size()
output_hidden_size = weight.size(0)
world_size = get_tensor_model_parallel_world_size()
current_flux_params = (
sequence_len,
batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype
)
if (
self.fw_ag_gemm_op is None
or current_flux_params != self.previous_flux_params
):
if not is_flux_min_version("1.1.0"):
self.fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_hidden_size,
input_parallel.dtype,
output_dtype=input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
self.bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
input_hidden_size,
input_parallel.dtype,
input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
fuse_reduction=False
)
self.previous_flux_params = current_flux_params
allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
forward_params = {
"input": input_parallel,
"weight": weight,
"bias": bias,
"gradient_accumulation_fusion": self.gradient_accumulation_fusion,
"allreduce_dgrad": allreduce_dgrad,
"sequence_parallel": False if self.explicit_expert_comm else self.sequence_parallel,
"grad_output_buffer": self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
"wgrad_deferral_limit": self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
}
if self.use_flux:
forward_params.update({"transpose_weight": self.flux_transpose_weight})
output_parallel = self._forward_impl(**forward_params)
output_parallel = self._forward_impl(
input=input_parallel,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
allreduce_dgrad=allreduce_dgrad,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
wgrad_deferral_limit=self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
transpose_weight=self.flux_transpose_weight,
fw_ag_gemm_op=self.fw_ag_gemm_op,
bw_gemm_rs_op=self.bw_gemm_rs_op
)
gather_output = self.gather_output
# Use the runtime gather output if it's set explicitly.
......@@ -902,15 +1024,89 @@ class ColumnParallelLinearPatch(torch.nn.Module):
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def __repr__(self):
tp = self.output_size // self.output_size_per_partition
use_bias = self.bias is not None
return (
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size_per_partition}, bias={use_bias}, TP={tp})"
)
class RowParallelLinearPatch(torch.nn.Module):
class FluxRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]
Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias. Note that bias is not parallelized.
input_is_parallel:
If true, we assume that the input is already split across the GPUs
and we do not split again.
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It returns the master weights
used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimizations where bias can be fused with other
elementwise operations.
is_expert:
If True, the layer is treated as an MoE expert layer
tp_comm_buffer_name:
Communication buffer name. Not used in non-Transformer-Engine modules.
config:
ModelParallelConfig object
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
):
super(FluxRowParallelLinear, self).__init__(
input_size=input_size,
output_size=output_size,
config=config,
init_method=init_method,
bias=bias,
input_is_parallel=input_is_parallel,
skip_bias_add=skip_bias_add,
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name
)
# flux params
self._forward_impl = linear_rs
self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
self.previous_flux_params = (None,) * 5
self.fw_gemm_rs_op = None
self.bw_ag_gemm_op = None
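# Mirror of the column-parallel case: forward() caches a GemmRS kernel for the fused
# GEMM + reduce-scatter and an AGKernel for the backward all-gather, rebuilding them
# only when previous_flux_params changes.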
def forward(self, input_):
"""Forward of RowParallelLinear
......@@ -934,45 +1130,86 @@ class RowParallelLinearPatch(torch.nn.Module):
else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if self.use_flux:
self._forward_impl = linear_rs
elif not self.weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
allreduce_dgrad = False
# flux kernels
forward_params = {
"input": input_parallel,
"weight": self.weight,
"bias": None if not self.use_flux or self.skip_bias_add else self.bias,
"gradient_accumulation_fusion": self.gradient_accumulation_fusion,
"allreduce_dgrad": allreduce_dgrad,
"sequence_parallel": False if not self.use_flux else self.sequence_parallel,
"grad_output_buffer": False,
}
if self.sequence_parallel:
sequence_len, batch_size, input_hidden_size = input_parallel.size()
output_hidden_size = self.weight.size(0)
world_size = get_tensor_model_parallel_world_size()
if self.use_flux:
forward_params.update({"transpose_weight": self.flux_transpose_weight})
current_flux_params = (
sequence_len,
batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype
)
output_parallel = self._forward_impl(**forward_params)
if self.use_flux:
return output_parallel, None if skip_bias_add else self.bias
if (
self.fw_gemm_rs_op is None
or current_flux_params != self.previous_flux_params
):
if not is_flux_min_version("1.1.0"):
self.fw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
output_hidden_size,
input_parallel.dtype,
input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
fuse_reduction=False
)
self.bw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype,
output_dtype=input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
self.previous_flux_params = current_flux_params
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
allreduce_dgrad=False,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=None,
transpose_weight=self.flux_transpose_weight,
fw_gemm_rs_op=self.fw_gemm_rs_op,
bw_ag_gemm_op=self.bw_ag_gemm_op
)
# All-reduce across all the partitions.
if self.explicit_expert_comm:
assert self.skip_bias_add
output_ = output_parallel
elif self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
output_ = output_parallel
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = (output_ + self.bias) if self.bias is not None else output_
output_bias = None
output = (output_ + self.bias) if self.bias is not None else output_
else:
output = output_
output_bias = self.bias
return output, output_bias
def __repr__(self):
tp = self.input_size // self.input_size_per_partition
use_bias = self.bias is not None
return (
f"{type(self).__name__}(in_features={self.input_size_per_partition}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import os
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal
......@@ -11,6 +11,7 @@ from megatron.core.models.common.embeddings.language_model_embedding import Lang
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
......@@ -136,18 +137,22 @@ class MultiTokenPredictor(MegatronModule):
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=self.add_output_layer_bias,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
column_parallel_linear_impl = FluxColumnParallelLinear
else:
column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = column_parallel_linear_impl(
self.config.hidden_size,
self.vocab_size,
config=self.config,
init_method=self.config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
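# With USE_FLUX_OVERLAP enabled the MTP output head uses the same fused flux
# column-parallel projection as the main model; otherwise it keeps the stock
# tensor_parallel.ColumnParallelLinear.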
def forward(
self,
......
from contextlib import nullcontext
from typing import Optional
from functools import wraps
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import make_viewless_tensor
try:
from megatron.core.extensions.transformer_engine import TEDelayedScaling
HAVE_TE = True
except ImportError:
HAVE_TE = False
def transformer_block_init_wrapper(fn):
@wraps(fn)
......@@ -25,178 +8,8 @@ def transformer_block_init_wrapper(fn):
# MTP requires separate layernorms for the main model and the MTP modules, so move the final norm out of the block
config = args[0] if len(args) > 1 else kwargs['config']
self.move_final_norm_out_of_block = getattr(config, "num_nextn_predict_layers", 0) > 0
if getattr(config, "num_nextn_predict_layers", 0) > 0:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
return wrapper
def transformer_block_forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: Tensor = None,
):
"""
Perform the forward pass through the transformer block.
This method handles the core computation of the transformer, including
self-attention, optional cross-attention, and feed-forward operations.
Args:
hidden_states (Tensor): Input tensor of shape [s, b, h] where s is the
sequence length, b is the batch size, and h is the hidden size.
attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
self-attention.
context (Tensor, optional): Context tensor for cross-attention.
context_mask (Tensor, optional): Mask for cross-attention context
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable
to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
Used as an alternative to apply attention mask for TE cuDNN attention.
inference_params (InferenceParams, optional): Parameters for inference-time
optimizations.
packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
processing.
Returns:
Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Update the inference parameters with the current batch size in case it is variable
if inference_params and not self.training:
inference_params.current_batch_size = hidden_states.size(1)
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(
with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red
)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
with rng_context, fp8_context:
# Forward pass.
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
else:
for l_no, layer in enumerate(self.layers):
with self.offload_context:
layer.use_cudagraph = True
if (len(self.cuda_graphs) == 0) or (not self.training):
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
else:
# CUDA graph replay for layer `l_no` and microbatch
# `self.current_microbatch`. TransformerEngine versions>=1.10
# allow keyword arguments with CUDA graph. However, CUDA graph
# accepts only Tensor inputs and Tensor outputs. Hence,
# `inference_params` and `packed_seq_params` are excluded from
# input list while output is limited to `hidden_states`.
cg_index = self.current_microbatch % len(self.cuda_graphs[l_no])
assert not any(
[inference_params, packed_seq_params]
), "CUDA graph accepts only Tensor inputs."
optional_inputs = self.get_cuda_graph_optional_args(
attention_mask,
context,
context_mask,
rotary_pos_emb,
attention_bias,
inference_params,
packed_seq_params,
)
hidden_states = self.cuda_graphs[l_no][cg_index](
hidden_states, **optional_inputs
)
if (
torch.is_grad_enabled()
and self.config.cpu_offloading
and self.group_prefetch_offload_commit_async is not None
):
hidden_states = self.group_prefetch_offload_commit_async(hidden_states)
# Final layer norm.
if (not self.move_final_norm_out_of_block) and self.final_layernorm is not None:
hidden_states = self.final_layernorm(hidden_states)
# TENorm produces a "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True
)
return hidden_states
......@@ -26,9 +26,6 @@ class ExtraTransformerConfig:
    ##################
    # flux
    ##################
    use_flux: bool = False
    """If set, flux will be used in ColumnParallelLinear and RowParallelLinear"""
    flux_transpose_weight: bool = False
......
import torch
from typing import List, Optional, Union
from importlib.metadata import version
from packaging.version import Version as PkgVersion

_flux_version = None


def get_flux_version():
    """Get flux version from __version__; if not available use pip's. Use caching."""

    def get_flux_version_str():
        import flux

        if hasattr(flux, '__version__'):
            return str(flux.__version__)
        else:
            return version("flux")

    global _flux_version
    if _flux_version is None:
        _flux_version = PkgVersion(get_flux_version_str())
    return _flux_version


def is_flux_min_version(version, check_equality=True):
    """Check if minimum version of `flux` is installed."""
    if check_equality:
        return get_flux_version() >= PkgVersion(version)
    return get_flux_version() > PkgVersion(version)
def tensor_slide(
......
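A hedged usage sketch of the version helpers above: gating an optional flux code path on a minimum version. It assumes the helpers are importable from wherever this module lives, and the `"1.0.0"` threshold is a made-up example rather than a real requirement.
```
from importlib.metadata import PackageNotFoundError

# Hypothetical threshold; the real requirement depends on the flux build in use.
try:
    use_flux_path = is_flux_min_version("1.0.0")
except (ImportError, PackageNotFoundError):
    use_flux_path = False  # flux is not installed at all

print("flux overlap path enabled:", use_flux_path)
```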
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch
from torch import nn


class RMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False,
                 config: dict = None):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to true if sequence parallelism is being used,
              this marks the weights as needing to be allreduced.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
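As a quick sanity check (not part of the patch), the RMS statistic can be verified against a hand-written computation. The snippet below exercises the un-compiled `_norm` helper directly so it runs even where no `torch.compile` backend is available; shapes and the seed are arbitrary.
```
import torch

torch.manual_seed(0)
norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(2, 4, 8)

# Mirror forward() without going through the compiled wrapper.
y = norm._norm(x.float()).type_as(x) * norm.weight
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
print(torch.allclose(y, expected, atol=1e-6))  # True
```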
import torch
from typing import Any, Callable, Dict, Optional, Tuple, Union

import lightop  # provides rmsnorm_forward / rmsnorm_backward

from functools import partial

from megatron.core.utils import is_torch_min_version

if is_torch_min_version("2.4.0a0"):
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd


class _LightopRMSNorm(torch.autograd.Function):
    """RMSNorm implemented with lightop."""

    @staticmethod
    # @custom_fwd
    def forward(ctx,
                inp: torch.Tensor,
                weight: torch.Tensor,
                ln_out: torch.Tensor,
                eps: float,
                is_grad_enabled):
        # lightop returns (normalized output, rsigma)
        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
        rsigma = output[1]
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight, rsigma)
        return output[0]

    @staticmethod
    # @custom_bwd
    def backward(ctx, grad_output):
        inp, weight, rsigma = ctx.saved_tensors
        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
        return dgrad, dgamma, None, None, None


class LightopRMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    # @no_torch_dynamo()  # dynamically applies torch._dynamo.disable
    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
        if torch.is_grad_enabled():
            fwd_fn = _LightopRMSNorm.apply
            args = []
        else:
            fwd_fn = _LightopRMSNorm.forward
            args = [None]
        ln_out = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
        args += (inp, self.weight, ln_out, self.eps, torch.is_grad_enabled())
        out = fwd_fn(*args)
        return out
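If the lightop extension is available, a rough equivalence check against the pure-PyTorch `RMSNorm` above is one way to validate the fused path. This is only a sketch under those assumptions (a visible GPU, bf16 support, and lightop returning `(output, rsigma)` as used above); the tolerances are illustrative.
```
import torch

hidden = 4096
x = torch.randn(8, 2, hidden, dtype=torch.bfloat16, device="cuda")

ref = RMSNorm(dim=hidden).to(device="cuda", dtype=torch.bfloat16)
fused = LightopRMSNorm(dim=hidden).to(device="cuda", dtype=torch.bfloat16)

with torch.no_grad():
    y_ref = ref._norm(x.float()).type_as(x) * ref.weight
    y_fused = fused(x)

print(torch.allclose(y_ref, y_fused, rtol=1e-2, atol=1e-2))
```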
......@@ -3,14 +3,21 @@ import torch.nn.functional as F
from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnType
from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.transformer import ParallelMLP
from megatron.legacy.model.utils import (
    erf_gelu,
    openai_gelu,
)

try:
    from einops import rearrange
except ImportError:
    rearrange = None


class ParallelMLP(MegatronModule):
class ParallelMLPPatch(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
......@@ -74,7 +81,7 @@ class ParallelMLP(MegatronModule):
    )


class ParallelAttention(MegatronModule):
class ParallelAttentionPatch(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
from megatron.training import get_args
from megatron.legacy.model import LayerNorm

from .rms_norm import RMSNorm, LightopRMSNorm


def get_norm(config):
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)
    elif args.normalization == "RMSNorm":
        if args.apply_layernorm_1p:
            raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')

        return RMSNorm(dim=config.hidden_size,
                       eps=config.layernorm_epsilon,
                       sequence_parallel=config.sequence_parallel)
    elif args.normalization == "LightopRMSNorm":
        return LightopRMSNorm(dim=config.hidden_size,
                              eps=config.layernorm_epsilon)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")
......@@ -51,6 +51,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_extra_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_extra_training_args(parser)
......@@ -106,6 +107,18 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    return args


def _add_extra_network_size_args(parser):
    # Remove the original argument definition.
    remove_original_params(parser, ["normalization"])

    # Redefine the argument with the extra choice.
    group = parser.add_argument_group(title='extra network size args')
    group.add_argument('--normalization', default='LayerNorm',
                       choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
                       help='Which normalization technique to use.')
    return parser


def _add_extra_distributed_args(parser):
    group = parser.add_argument_group(title='extra distributed args')
    group.add_argument('--rank', default=-1, type=int,
......@@ -169,9 +182,7 @@ def _add_mtp_args(parser):


def _add_flux_args(parser):
    group = parser.add_argument_group(title='multi token prediction')
    group.add_argument('--use-flux', action='store_true', default=False,
                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
    group = parser.add_argument_group(title='flux args')
    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                       help='Whether to transpose weight when using flux kernel')
    return parser
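The `_add_extra_network_size_args` hunk above follows a remove-then-redefine pattern so the stock `--normalization` flag can gain the `LightopRMSNorm` choice. The project's `remove_original_params` helper is not shown here; the sketch below gets a comparable effect with plain `argparse` and `conflict_handler='resolve'`, purely for illustration.
```
import argparse

parser = argparse.ArgumentParser(conflict_handler='resolve')

# What stock Megatron would register:
parser.add_argument('--normalization', default='LayerNorm',
                    choices=['LayerNorm', 'RMSNorm'],
                    help='Which normalization technique to use.')

# The patched re-registration, now also accepting LightopRMSNorm;
# 'resolve' drops the earlier definition instead of raising a conflict.
parser.add_argument('--normalization', default='LayerNorm',
                    choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
                    help='Which normalization technique to use.')

args = parser.parse_args(['--normalization', 'LightopRMSNorm'])
print(args.normalization)  # LightopRMSNorm
```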
#!/bin/bash
# set -eux
#export FLASH_ATTENTION_PRINT_PARAM=1
# Runs the "7B" parameter model
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export NCCL_P2P_LEVEL=PXB # SYS
#export HIP_ALLOC_INITIALIZE=0
# export GPU_MAX_HW_QUEUES=10
export NCCL_ALGO=Ring
export NCCL_NCHANNELS_PER_PEER=16
export NCCL_MIN_NCHANNELS=32 # 20
export NCCL_MAX_NCHANNELS=32 # 20
export NCCL_IB_TIMEOUT=22
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_NET_GDR_LEVEL=7
export NCCL_NET_GDR_READ=1
export RCCL_SDMA_COPY_ENABLE=0
export NCCL_TOPO_FILE="/public/home/wangxj/Projects/rccl-test/rccl-tests-0204/topo-input.xml"
# export NCCL_TOPO_FILE="/workspace/rccl-test/rccl-tests-0204/topo-input.xml"
export GLOG_minloglevel=3 # print only error-level nccl logs
source /opt/dtk/env.sh
# load the hipblaslt library
# export LD_LIBRARY_PATH=/data/hipblaslt-install-0904/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/hipblaslt-install-dtk-25.04-0212/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/public/home/wangxj/Downloads/blas/hipblaslt-install-dtk-25.04-0212/lib:$LD_LIBRARY_PATH
# updated rocblas
# export LD_LIBRARY_PATH=/data/rocblas-install_qwen1211/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/rocblas-install_qwen1228/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/rocblas-install-0118-bf16/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/rocblas-install-0203-release/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/public/home/wangxj/Downloads/blas/rocblas-install-0331-release/lib:$LD_LIBRARY_PATH
# torch: collapse multiple communication streams onto a single compute stream
export ALLREDUCE_STREAM_WITH_COMPUTE=1
export SENDRECV_STREAM_WITH_COMPUTE=1
# add synchronization for profiling captures to avoid stalls
# export GPU_FLUSH_ON_EXECUTION=1
# export HIP_DIRECT_DISPATCH=0
# capture rocblas sizes
# export ROCBLAS_LAYER=3
# export HIPBLASLT_LOG_LEVEL=3
# capture flash-attention sizes
# export FLASH_ATTENTION_PRINT_PARAM=1
# enlarge the compile cache
export cache_size_limit=64
# lightop kernel library
export PYTORCH_ROCM_ARCH='gfx906;gfx926;gfx936'
# CHECKPOINT_PATH=./Llama-2-7b-hf-to-meg-tp1-pp2 #CHECKPOINT_PATH=./tmp_7b #
SAVE_PATH=./tmp_7b
TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
DATA_PATH="/public/home/gmhtest_tmp/RedPajama-Data-1T-Sample/redpajama_text_document" #<Specify path and file prefix>_text_document
# DATA_PATH="/data/datasets/oscar-1GB-head/oscar-1GB_head-llama3.2_text_document" #<Specify path and file prefix>_text_document
GPT_MODEL_ARGS=(
--num-layers 80 #80 #80 #40 # 20 #
--hidden-size 8192
--ffn-hidden-size 22016 # 28672
--num-attention-heads 64
--max-position-embeddings 8192
--group-query-attention
--num-query-groups 8
--normalization RMSNorm
--position-embedding-type rope
--untie-embeddings-and-output-weights # handle embedding and output weights separately for extra flexibility
)
export NVTE_FLASH_ATTN=1 # use the cutlass path
# export NVTE_FLASH_ATTN_TRITON=1 # use the triton flash-attention path
# --transformer-impl transformer_engine # use these two options for the mcore path
# --use-mcore-models
# --transformer-impl local # use these two options for the legacy path
# --use-legacy-models
TRAINING_ARGS=(
--transformer-impl local # use these two options for the legacy path
--use-legacy-models
--micro-batch-size 1
--global-batch-size 512 #32 #512 #256 # 64 #240 #60 #512 #64
--train-iters 300
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
--init-method-std 0.006
--clip-grad 1.0
--bf16
# --fp16 # enabling fp16 requires specifying loss-scale
# --loss-scale 1024
--use-distributed-optimizer
--disable-bias-linear
--attention-dropout 0
--hidden-dropout 0
# --no-gradient-accumulation-fusion
# --no-check-for-nan-in-loss-and-grad
--swiglu
--lr 3.0e-5
--lr-decay-style cosine
--min-lr 3.0e-6
--lr-warmup-iters 1
--ckpt-format torch
--ddp-average-in-collective # in DP communication, average gradients/params directly instead of summing (onto one device) and then averaging
# --recompute-activations
# --recompute-granularity full # recomputation reduces memory usage at the cost of extra time
# --recompute-num-layers 1 #0 #
# --recompute-method block
--overlap-grad-reduce # overlap ddp grad reduce
# --tp-comm-overlap # overlap tensor parallel comm with gemm, requires mcore
# --tp-comm-overlap-rs-dgrad # overlap reduce-scatter with the dgrad gemm, requires mcore
--use-flash-attn
)
# export TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
# export TORCHINDUCTOR_BENCHMARK_FUSION=1
# export TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES=1
# export TORCHINDUCTOR_MAX_AUTOTUNE=1
# export TORCHINDUCTOR_CACHE_DIR=./cache
# --use-flash-attn-cutlass # cutlass fa
# --use-flash-attn-triton # triton fa
# --use-flash-attn-torch # torch fa
MODEL_PARALLEL_ARGS=(
--sequence-parallel
--tensor-model-parallel-size 4
--pipeline-model-parallel-size 8
--context-parallel-size 1
# --num-layers-per-virtual-pipeline-stage 1
# --microbatch-group-size-per-virtual-pipeline-stage 5
# --no-overlap-p2p-communication # when enabled
)
DATA_ARGS=(
--data-path $DATA_PATH
--seq-length 4096 #8192 #4096
--split 949,50,1
--tokenizer-type Llama2Tokenizer
--tokenizer-model /public/home/gmhtest_tmp/RedPajama-Data-1T-Sample/tokenizer.model
# --tokenizer-model /data/model_weights/llama3.2/tokenizer.model
)
EVAL_AND_LOGGING_ARGS=(
--log-interval 1
--log-throughput
--save-interval 500
--eval-interval 50
--eval-iters 3
--save $SAVE_PATH
--load $SAVE_PATH
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)
# FINETUNE_ARGS=(
# # --finetune
# # --pretrained-checkpoint $CHECKPOINT_PATH
# --load $CHECKPOINT_PATH
# --no-load-optim
# --no-load-rng
# )
PROFILE_ARGS=(
--profile
--profile-step-start 4
--profile-step-end 5
--use-pytorch-profiler
--profile-ranks 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
--profile-dir prof_data
)
RANK=$OMPI_COMM_WORLD_RANK
LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
DIST_URL=${1}
DIST_PORT=34577
DISTRIBUTED_ARGS=(
--rank ${RANK}
--world-size ${WORLD_SIZE}
--local-rank ${LOCAL_RANK}
--dist-url tcp://${DIST_URL}:${DIST_PORT}
)
APP="python -u ../../pretrain_gpt.py \
${GPT_MODEL_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${DATA_ARGS[@]} \
${EVAL_AND_LOGGING_ARGS[@]} \
${DISTRIBUTED_ARGS[@]} \
"
# enable profiling
# ${PROFILE_ARGS[@]} \
# export HIP_VISIBLE_DEVICES=0,7 # # 4,5,6,7 #,
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # # 4,5,6,7 #,
# export CUDA_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3,
# ${APP}
case ${LOCAL_RANK} in
[0])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
numactl --cpunodebind=0 --membind=0 ${APP}
;;
[1])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
numactl --cpunodebind=1 --membind=1 ${APP}
;;
[2])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
numactl --cpunodebind=2 --membind=2 ${APP}
;;
[3])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=3 --membind=3 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[4])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=4 --membind=4 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[5])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=5 --membind=5 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[6])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=6 --membind=6 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[7])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=7 --membind=7 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
esac
\ No newline at end of file
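For reference, the parallel layout implied by the script above (tensor parallel 4, pipeline parallel 8, micro batch 1, global batch 512) can be sanity-checked with a little arithmetic. The 256-GPU world size below is an assumed example, since the real value comes from the mpirun launch; the helper itself is illustrative.
```
def data_parallel_size(world_size: int, tp: int, pp: int, cp: int = 1) -> int:
    # World size must factor into tp * pp * cp replicas of the model.
    assert world_size % (tp * pp * cp) == 0, "world size must be divisible by tp*pp*cp"
    return world_size // (tp * pp * cp)

dp = data_parallel_size(world_size=256, tp=4, pp=8)  # e.g. 32 nodes x 8 GPUs
grad_acc = 512 // (dp * 1)                           # global batch / (dp * micro batch)
print(dp, grad_acc)                                  # 8 64
```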
#!/bin/bash
# set -eux
for para in $*
do
if [[ $para == --profiling* ]];then
profiling=${para#*=}
fi
done
CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
#default env
#export FLASH_ATTENTION_PRINT_PARAM=1
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export NCCL_P2P_LEVEL=PXB # SYS
# export GPU_MAX_HW_QUEUES=10
#export HIP_ALLOC_INITIALIZE=0
export CUDA_DEVICE_MAX_CONNECTIONS=1
# nccl env
export NCCL_ALGO=Ring
export NCCL_NCHANNELS_PER_PEER=16
export NCCL_MIN_NCHANNELS=32 # 20
export NCCL_MAX_NCHANNELS=32 # 20
export NCCL_IB_TIMEOUT=22
export NCCL_NET_GDR_LEVEL=7
export NCCL_NET_GDR_READ=1
export RCCL_SDMA_COPY_ENABLE=0
export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_TOPO_FILE="/workspace/rccl-test/rccl-tests-0204/topo-input.xml"
export GLOG_minloglevel=3 # print only error-level nccl logs
source /opt/dtk/env.sh
# hipblaslt library
export LD_LIBRARY_PATH=/data/blas/hipblaslt-install-dtk-25.04-0212/lib:$LD_LIBRARY_PATH
# rocblas
export LD_LIBRARY_PATH=/data/blas/rocblas-install-0331-release/lib:$LD_LIBRARY_PATH
# torch: collapse multiple communication streams onto a single compute stream
export ALLREDUCE_STREAM_WITH_COMPUTE=1
export SENDRECV_STREAM_WITH_COMPUTE=1
# enlarge the compile cache
export cache_size_limit=64
# CHECKPOINT_PATH=./Llama-2-7b-hf-to-meg-tp1-pp2 #CHECKPOINT_PATH=./tmp_7b #
SAVE_PATH=./tmp_7b
TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
DATA_PATH="/data/datasets/oscar-1GB/oscar-1GB-llama2_text_document" #<Specify path and file prefix>_text_document
GPT_MODEL_ARGS=(
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 11008
--num-attention-heads 32
--max-position-embeddings 4096
--normalization RMSNorm # LightopRMSNorm
--position-embedding-type rope # none #
--untie-embeddings-and-output-weights # handle embedding and output weights separately for extra flexibility
)
export NVTE_FLASH_ATTN=1 # use the cutlass path
# export NVTE_FLASH_ATTN_TRITON=1 # use the triton flash-attention path
# --transformer-impl transformer_engine # use these two options for the mcore path
# --use-mcore-models
# --transformer-impl local # use these two options for the legacy path
# --use-legacy-models
TRAINING_ARGS=(
--transformer-impl local # use these two options for the legacy path
--use-legacy-models
--micro-batch-size 1
--global-batch-size 256 #256 #240 #60 #512 #64
--train-iters 50
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
--init-method-std 0.006
--clip-grad 1.0
--bf16
# --fp16 # enabling fp16 requires specifying loss-scale
# --loss-scale 1024
--use-distributed-optimizer
--disable-bias-linear
--attention-dropout 0
--hidden-dropout 0
# --no-gradient-accumulation-fusion
--swiglu
--lr 3.0e-5
--lr-decay-style cosine
--min-lr 3.0e-6
--lr-warmup-iters 1
--ckpt-format torch
--ddp-average-in-collective # in DP communication, average gradients/params directly instead of summing (onto one device) and then averaging
# --recompute-granularity full # recomputation reduces memory usage at the cost of extra time
# --recompute-num-layers 5 #0 #
# --recompute-method block
--overlap-grad-reduce # overlap ddp grad reduce
# --tp-comm-overlap # overlap tensor parallel comm with gemm, optimization not yet adapted
# --tp-comm-overlap-rs-dgrad # overlap reduce-scatter with the dgrad gemm
--use-flash-attn
)
# environment variables for the torch flash-attention path
# export TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
# export TORCHINDUCTOR_BENCHMARK_FUSION=1
# export TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES=1
# export TORCHINDUCTOR_MAX_AUTOTUNE=1
# export TORCHINDUCTOR_CACHE_DIR=./cache
# --use-flash-attn-cutlass # cutlass fa
# --use-flash-attn-triton # triton fa
# --use-flash-attn-torch # torch fa
MODEL_PARALLEL_ARGS=(
--sequence-parallel
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 2
# --context-parallel-size 2
# --num-layers-per-virtual-pipeline-stage 4
# --microbatch-group-size-per-virtual-pipeline-stage 1
# --no-overlap-p2p-communication # when enabled
)
DATA_ARGS=(
--data-path $DATA_PATH
--seq-length 4096 #4096
--split 949,50,1
--tokenizer-type Llama2Tokenizer
--tokenizer-model /data/model_weights/llama2_7b_hf/tokenizer.model
)
EVAL_AND_LOGGING_ARGS=(
--log-throughput
--eval-iters 50
--log-interval 1
--save-interval 1000
--eval-interval 1000
--save $SAVE_PATH
--load $SAVE_PATH
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)
# FINETUNE_ARGS=(
# # --finetune
# # --pretrained-checkpoint $CHECKPOINT_PATH
# --load $CHECKPOINT_PATH
# --no-load-optim
# --no-load-rng
# )
PROFILE_ARGS=(
--profile
--profile-step-start 4
--profile-step-end 5
--use-pytorch-profiler
--profile-ranks 0 1 2 3 4 5 6 7
--profile-dir prof_data
)
RANK=$OMPI_COMM_WORLD_RANK
LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
DIST_URL=${1}
DIST_PORT=34577
DISTRIBUTED_ARGS=(
--rank ${RANK}
--world-size ${WORLD_SIZE}
--local-rank ${LOCAL_RANK}
--dist-url tcp://${DIST_URL}:${DIST_PORT}
)
APP="python -u ${MEGATRON_PATH}/pretrain_gpt.py \
${GPT_MODEL_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${DATA_ARGS[@]} \
${EVAL_AND_LOGGING_ARGS[@]} \
${DISTRIBUTED_ARGS[@]} \
"
# enable profiling
# ${PROFILE_ARGS[@]} \
# export HIP_VISIBLE_DEVICES=0,7 # # 4,5,6,7 #,
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # # 4,5,6,7 #,
# export CUDA_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3,
# ${APP}
case ${LOCAL_RANK} in
[0])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
numactl --cpunodebind=0 --membind=0 ${APP}
;;
[1])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
numactl --cpunodebind=1 --membind=1 ${APP}
;;
[2])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
numactl --cpunodebind=2 --membind=2 ${APP}
;;
[3])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=3 --membind=3 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[4])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=4 --membind=4 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[5])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=5 --membind=5 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[6])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=6 --membind=6 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
[7])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numactl --cpunodebind=7 --membind=7 ${APP}
# hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
;;
esac
\ No newline at end of file
#!/bin/bash
# set -eux
#export FLASH_ATTENTION_PRINT_PARAM=1
# Runs the "7B" parameter model
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export NCCL_P2P_LEVEL=PXB # SYS
#export HIP_ALLOC_INITIALIZE=0
# export GPU_MAX_HW_QUEUES=10
export NCCL_ALGO=Ring
export NCCL_NCHANNELS_PER_PEER=16
export NCCL_MIN_NCHANNELS=32 # 20
export NCCL_MAX_NCHANNELS=32 # 20
export NCCL_IB_TIMEOUT=22
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_NET_GDR_LEVEL=7
export NCCL_NET_GDR_READ=1
export RCCL_SDMA_COPY_ENABLE=0
export NCCL_TOPO_FILE="/public/home/wangxj/Projects/rccl-test/rccl-tests-0204/topo-input.xml"
# export NCCL_TOPO_FILE="/workspace/rccl-test/rccl-tests-0204/topo-input.xml"
export GLOG_minloglevel=3 # print only error-level nccl logs
source /opt/dtk/env.sh
# load the hipblaslt library
# export LD_LIBRARY_PATH=/data/hipblaslt-install-0904/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/hipblaslt-install-dtk-25.04-0212/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/public/home/wangxj/Downloads/blas/hipblaslt-install-dtk-25.04-0212/lib:$LD_LIBRARY_PATH
# updated rocblas
# export LD_LIBRARY_PATH=/data/rocblas-install_qwen1211/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/rocblas-install_qwen1228/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/rocblas-install-0118-bf16/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/data/rocblas-install-0203-release/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/public/home/wangxj/Downloads/blas/rocblas-install-0203-release/lib:$LD_LIBRARY_PATH
# torch: collapse multiple communication streams onto a single compute stream
export ALLREDUCE_STREAM_WITH_COMPUTE=1
export SENDRECV_STREAM_WITH_COMPUTE=1
# add synchronization for profiling captures to avoid stalls
# export GPU_FLUSH_ON_EXECUTION=1
# export HIP_DIRECT_DISPATCH=0
# capture rocblas sizes
# export ROCBLAS_LAYER=3
# capture flash-attention sizes
# export FLASH_ATTENTION_PRINT_PARAM=1
# enlarge the compile cache
export cache_size_limit=64
# CHECKPOINT_PATH=./Llama-2-7b-hf-to-meg-tp1-pp2 #CHECKPOINT_PATH=./tmp_7b #
SAVE_PATH=./tmp_7b
TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
DATA_PATH="/public/home/wangxj/Downloads/datasets/oscar-1GB-head/oscar-1GB_head-llama3.2_text_document" #<Specify path and file prefix>_text_document
# DATA_PATH="/data/datasets/oscar-1GB-head/oscar-1GB_head-llama3.2_text_document" #<Specify path and file prefix>_text_document
GPT_MODEL_ARGS=(
--num-layers 126 #96 #8 # 126
--hidden-size 16384
--ffn-hidden-size 53248
--num-attention-heads 128
--max-position-embeddings 16384
--group-query-attention
--num-query-groups 16
--normalization RMSNorm
--position-embedding-type rope
--untie-embeddings-and-output-weights # handle embedding and output weights separately for extra flexibility
)
export NVTE_FLASH_ATTN=1 # use the cutlass path
# export NVTE_FLASH_ATTN_TRITON=1 # use the triton flash-attention path
# --transformer-impl transformer_engine # use these two options for the mcore path
# --use-mcore-models
# --transformer-impl local # use these two options for the legacy path
# --use-legacy-models
TRAINING_ARGS=(
--transformer-impl transformer_engine # use these two options for the mcore path
--use-mcore-models
--micro-batch-size 1
--global-batch-size 6912 # 252 #32 # 64 #240 #60 #512 #64
--train-iters 100
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
--init-method-std 0.006
--clip-grad 1.0
--bf16
# --fp16 # enabling fp16 requires specifying loss-scale
# --loss-scale 1024
--use-distributed-optimizer
--disable-bias-linear
--attention-dropout 0
--hidden-dropout 0
# --no-gradient-accumulation-fusion
--swiglu
--lr 3.0e-5
--lr-decay-style cosine
--min-lr 3.0e-6
--lr-warmup-iters 1
--ckpt-format torch
--ddp-average-in-collective # in DP communication, average gradients/params directly instead of summing (onto one device) and then averaging
# --recompute-granularity full # recomputation reduces memory usage at the cost of extra time
# --recompute-num-layers 5 #0 #
# --recompute-method block
--overlap-grad-reduce # overlap ddp grad reduce
# --tp-comm-overlap # overlap tensor parallel comm with gemm, optimization not yet adapted
# --tp-comm-overlap-rs-dgrad # overlap reduce-scatter with the dgrad gemm, optimization not yet adapted
--use-flash-attn-cutlass
)
# export TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
# export TORCHINDUCTOR_BENCHMARK_FUSION=1
# export TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES=1
# export TORCHINDUCTOR_MAX_AUTOTUNE=1
# export TORCHINDUCTOR_CACHE_DIR=./cache
# --use-flash-attn-cutlass # cutlass fa
# --use-flash-attn-triton # triton fa
# --use-flash-attn-torch # torch fa
MODEL_PARALLEL_ARGS=(
--sequence-parallel
--tensor-model-parallel-size 8
--pipeline-model-parallel-size 18 # 7 layer/gpu
--context-parallel-size 2
)
DATA_ARGS=(
--data-path $DATA_PATH
--seq-length 4096 #4096
--split 949,50,1
--tokenizer-type Llama3Tokenizer
--tokenizer-model /public/home/wangxj/Downloads/model_weights/llama3.2/tokenizer.model
# --tokenizer-model /data/model_weights/llama3.2/tokenizer.model
)
EVAL_AND_LOGGING_ARGS=(
--log-interval 1
--log-throughput
--save-interval 1000
--eval-interval 1000
--save $SAVE_PATH
--load $SAVE_PATH
--eval-iters 10
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)
# FINETUNE_ARGS=(
# # --finetune
# # --pretrained-checkpoint $CHECKPOINT_PATH
# --load $CHECKPOINT_PATH
# --no-load-optim
# --no-load-rng
# )
PROFILE_ARGS=(
--profile
--profile-step-start 4
--profile-step-end 5
--use-pytorch-profiler
--profile-ranks 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
--profile-dir prof_data
)
RANK=$OMPI_COMM_WORLD_RANK
LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
DIST_URL=${1}
DIST_PORT=34577
DISTRIBUTED_ARGS=(
--rank ${RANK}
--world-size ${WORLD_SIZE}
--local-rank ${LOCAL_RANK}
--dist-url tcp://${DIST_URL}:${DIST_PORT}
)
APP="python -u pretrain_gpt.py \
${GPT_MODEL_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${DATA_ARGS[@]} \
${EVAL_AND_LOGGING_ARGS[@]} \
${DISTRIBUTED_ARGS[@]} \
"
# enable profiling
# ${PROFILE_ARGS[@]} \
# export HIP_VISIBLE_DEVICES=0,7 # # 4,5,6,7 #,
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # # 4,5,6,7 #,
# export CUDA_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3,
${APP}
# case ${LOCAL_RANK} in
# [0])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# numactl --cpunodebind=0 --membind=0 ${APP}
# ;;
# [1])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# numactl --cpunodebind=1 --membind=1 ${APP}
# ;;
# [2])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# numactl --cpunodebind=2 --membind=2 ${APP}
# ;;
# [3])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# numactl --cpunodebind=3 --membind=3 ${APP}
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# ;;
# [4])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# numactl --cpunodebind=4 --membind=4 ${APP}
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# ;;
# [5])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# numactl --cpunodebind=5 --membind=5 ${APP}
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# ;;
# [6])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# numactl --cpunodebind=6 --membind=6 ${APP}
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# ;;
# [7])
# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# numactl --cpunodebind=7 --membind=7 ${APP}
# # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP}
# ;;
# esac
\ No newline at end of file