Commit 31e933a8 authored by dongcl

Merge branch 'megatron_v0.11.0' of http://developer.sourcefind.cn/codes/OpenDAS/dcu_megatron into megatron_v0.11.0
parents bf212e29 4754c307
@@ -66,6 +66,10 @@ def unpermute(
):
```
### The project supports the [flux kernel](http://10.6.10.68/dcutoolkit/deeplearing/flux)
In tensor-parallel (TP) scenarios, users can opt into the flux communication-computation fusion operators for better training and inference performance. The project integrates flux by replacing the corresponding Transformer Engine methods; to enable it, set the environment variable USE_FLUX_OVERLAP=1 and set transformer-impl to transformer_engine.
### Usage
Go to the examples directory, which contains launch scripts for the related models; please download the required datasets yourself from: https://r0ddbu55vzx.feishu.cn/drive/folder/ZxHHfCoX4lg75td2hTqcmiAin3g
```
...
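For reference, a minimal launcher sketch for the two switches described in the README change above. The entry point, process count, and extra flags are illustrative placeholders; only USE_FLUX_OVERLAP=1 and --transformer-impl transformer_engine come from the README text itself.

```python
# Illustrative sketch; adapt script names and arguments to the scripts under examples/.
import os
import subprocess

env = dict(os.environ, USE_FLUX_OVERLAP="1")  # enable the flux comm-gemm overlap path
subprocess.run(
    [
        "torchrun", "--nproc_per_node=8", "pretrain_gpt.py",
        "--tensor-model-parallel-size", "8",          # flux targets TP setups
        "--transformer-impl", "transformer_engine",   # required by the flux integration
        # ... remaining model/data arguments from the example scripts ...
    ],
    env=env,
    check=True,
)
```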
@@ -50,6 +50,8 @@ from megatron.core.tensor_parallel.layers import (
    linear_with_frozen_weight,
    linear_with_grad_accumulation_and_async_allreduce
)

from dcu_megatron.core.utils import is_flux_min_version

_grad_accum_fusion_available = True
try:
@@ -208,18 +210,19 @@ class AGLinear(torch.autograd.Function):
        world_size = get_tensor_model_parallel_world_size()

        if fw_ag_gemm_op is None:
            if not is_flux_min_version("1.1.0"):
                fw_ag_gemm_op = flux.AGKernel(
                    get_tensor_model_parallel_group(),
                    get_tensor_model_parallel_node_size(),
                    sequence_len * batch_size * world_size,
                    output_hidden_size,
                    input_hidden_size,
                    input.dtype,
                    output_dtype=input.dtype,
                    transpose_weight=transpose_weight,
                    local_copy=False,
                    ring_mode=flux.AgRingMode.Auto,
                )

        output = fw_ag_gemm_op.forward(
            input.view(sequence_len * batch_size, -1),
@@ -281,16 +284,17 @@ class AGLinear(torch.autograd.Function):
        if bw_gemm_rs_op is None:
            input_hidden_size = weight.size(-1)

            if not is_flux_min_version("1.1.0"):
                bw_gemm_rs_op = flux.GemmRS(
                    get_tensor_model_parallel_group(),
                    get_tensor_model_parallel_node_size(),
                    sequence_len * batch_size,
                    input_hidden_size,
                    input.dtype,
                    input.dtype,
                    transpose_weight=transpose_weight,
                    fuse_reduction=False
                )

        grad_input = bw_gemm_rs_op.forward(
            grad_output.view(sequence_len * batch_size, -1),
@@ -523,16 +527,17 @@ class LinearRS(torch.autograd.Function):
        if sequence_parallel:
            if fw_gemm_rs_op is None:
                if not is_flux_min_version("1.1.0"):
                    fw_gemm_rs_op = flux.GemmRS(
                        get_tensor_model_parallel_group(),
                        get_tensor_model_parallel_node_size(),
                        sequence_len * batch_size,
                        output_hidden_size,
                        input.dtype,
                        input.dtype,
                        transpose_weight=transpose_weight,
                        fuse_reduction=False,
                    )

            output = fw_gemm_rs_op.forward(
                input.view(sequence_len * batch_size, -1),
@@ -592,18 +597,19 @@ class LinearRS(torch.autograd.Function):
        input_hidden_size = weight.size(-1)

        if bw_ag_gemm_op is None:
            if not is_flux_min_version("1.1.0"):
                bw_ag_gemm_op = flux.AGKernel(
                    get_tensor_model_parallel_group(),
                    get_tensor_model_parallel_node_size(),
                    sequence_len * batch_size * world_size,
                    input_hidden_size,
                    output_hidden_size,
                    grad_output.dtype,
                    output_dtype=input.dtype,
                    transpose_weight=transpose_weight,
                    local_copy=False,
                    ring_mode=flux.AgRingMode.Auto,
                )

        grad_input = bw_ag_gemm_op.forward(
            grad_output.view(sequence_len * batch_size, -1),
            weight if transpose_weight else weight.t().contiguous(),
@@ -961,29 +967,30 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
            self.fw_ag_gemm_op is None
            or current_flux_params != self.previous_flux_params
        ):
            if not is_flux_min_version("1.1.0"):
                self.fw_ag_gemm_op = flux.AGKernel(
                    get_tensor_model_parallel_group(),
                    get_tensor_model_parallel_node_size(),
                    sequence_len * batch_size * world_size,
                    output_hidden_size,
                    input_hidden_size,
                    input_parallel.dtype,
                    output_dtype=input_parallel.dtype,
                    transpose_weight=self.flux_transpose_weight,
                    local_copy=False,
                    ring_mode=flux.AgRingMode.Auto,
                )

            self.bw_gemm_rs_op = flux.GemmRS(
                get_tensor_model_parallel_group(),
                get_tensor_model_parallel_node_size(),
                sequence_len * batch_size * world_size,
                input_hidden_size,
                input_parallel.dtype,
                input_parallel.dtype,
                transpose_weight=self.flux_transpose_weight,
                fuse_reduction=False
            )

            self.previous_flux_params = current_flux_params
@@ -1143,29 +1150,30 @@ class FluxRowParallelLinear(RowParallelLinear):
            self.fw_gemm_rs_op is None
            or current_flux_params != self.previous_flux_params
        ):
            if not is_flux_min_version("1.1.0"):
                self.fw_gemm_rs_op = flux.GemmRS(
                    get_tensor_model_parallel_group(),
                    get_tensor_model_parallel_node_size(),
                    sequence_len * batch_size,
                    output_hidden_size,
                    input_parallel.dtype,
                    input_parallel.dtype,
                    transpose_weight=self.flux_transpose_weight,
                    fuse_reduction=False
                )

            self.bw_ag_gemm_op = flux.AGKernel(
                get_tensor_model_parallel_group(),
                get_tensor_model_parallel_node_size(),
                sequence_len * batch_size,
                input_hidden_size,
                output_hidden_size,
                input_parallel.dtype,
                output_dtype=input_parallel.dtype,
                transpose_weight=self.flux_transpose_weight,
                local_copy=False,
                ring_mode=flux.AgRingMode.Auto,
            )

            self.previous_flux_params = current_flux_params
...
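The hunks above all apply the same change. As a standalone summary, the pattern looks roughly like the sketch below; the cache name and the build_op callback are hypothetical stand-ins for the module-level fw_ag_gemm_op/bw_gemm_rs_op objects and the flux.AGKernel/flux.GemmRS constructor calls shown in the diff.

```python
# Distilled sketch of the version-gated lazy construction used throughout the diff.
from dcu_megatron.core.utils import is_flux_min_version

_cached_op = None  # hypothetical stand-in for fw_ag_gemm_op / bw_gemm_rs_op


def get_or_build_op(build_op):
    """Build the flux op once, and only when the installed flux predates 1.1.0."""
    global _cached_op
    if _cached_op is None:
        if not is_flux_min_version("1.1.0"):
            _cached_op = build_op()  # e.g. a flux.AGKernel(...) or flux.GemmRS(...) call
        # flux >= 1.1.0 presumably takes a different construction path
        # elsewhere in the file (not shown in these hunks).
    return _cached_op
```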
import torch
from typing import List, Optional, Union
from importlib.metadata import version
from packaging.version import Version as PkgVersion

_flux_version = None


def get_flux_version():
    """Get flux version from __version__; if not available use pip's. Use caching."""

    def get_flux_version_str():
        import flux

        if hasattr(flux, '__version__'):
            return str(flux.__version__)
        else:
            return version("flux")

    global _flux_version
    if _flux_version is None:
        _flux_version = PkgVersion(get_flux_version_str())
    return _flux_version


def is_flux_min_version(version, check_equality=True):
    """Check if minimum version of `flux` is installed."""
    if check_equality:
        return get_flux_version() >= PkgVersion(version)
    return get_flux_version() > PkgVersion(version)


def tensor_slide(
...
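A small usage sketch for the new helpers, assuming flux is importable in the environment; this mirrors how the patched layers decide between the old and new flux code paths.

```python
# Checking the installed flux version the same way the patched layers do.
from dcu_megatron.core.utils import get_flux_version, is_flux_min_version

print(get_flux_version())             # cached packaging Version, e.g. 1.0.0
if not is_flux_min_version("1.1.0"):
    print("flux < 1.1.0: construct flux.AGKernel / flux.GemmRS as before")
else:
    print("flux >= 1.1.0: take the newer code path")
```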