"vscode:/vscode.git/clone" did not exist on "3d1987d14997b80fba1f64ae8c9133ab3a96e44e"
Commit 4754c307 authored by dongcl

check flux version

parent b3371c31
@@ -50,6 +50,8 @@ from megatron.core.tensor_parallel.layers import (
     linear_with_frozen_weight,
     linear_with_grad_accumulation_and_async_allreduce
 )
+from dcu_megatron.core.utils import is_flux_min_version
+
 _grad_accum_fusion_available = True
 try:
@@ -208,18 +210,19 @@ class AGLinear(torch.autograd.Function):
         world_size = get_tensor_model_parallel_world_size()

         if fw_ag_gemm_op is None:
-            fw_ag_gemm_op = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size * world_size,
-                output_hidden_size,
-                input_hidden_size,
-                input.dtype,
-                output_dtype=input.dtype,
-                transpose_weight=transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
+            if not is_flux_min_version("1.1.0"):
+                fw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size * world_size,
+                    output_hidden_size,
+                    input_hidden_size,
+                    input.dtype,
+                    output_dtype=input.dtype,
+                    transpose_weight=transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )

         output = fw_ag_gemm_op.forward(
             input.view(sequence_len * batch_size, -1),
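Note: this hunk, and the similar ones below, all apply the same change: the flux op is still built lazily on first use, but the legacy constructor is now only invoked when the installed flux is older than 1.1.0. A minimal, self-contained sketch of that lazy, version-gated construction pattern is shown here; the builder values, the stubbed installed version, and the ">=1.1.0" branch are hypothetical stand-ins, not part of the diff.

# Sketch only: mirrors the "if cached op is None -> version check -> construct" shape of the hunks above.
from packaging.version import Version as PkgVersion

_cached_op = None  # module-level cache, populated on first call


def _installed_version() -> PkgVersion:
    # Hypothetical: the real code gets this from get_flux_version() in dcu_megatron.core.utils.
    return PkgVersion("1.0.3")


def is_min_version(v: str) -> bool:
    return _installed_version() >= PkgVersion(v)


def get_op() -> str:
    """Build the op once; choose the constructor based on the installed version."""
    global _cached_op
    if _cached_op is None:
        if not is_min_version("1.1.0"):
            _cached_op = "legacy-op"   # stand-in for the flux.AGKernel(...) call in the diff
        else:
            _cached_op = "new-api-op"  # stand-in for whatever the >=1.1.0 path uses (not shown in this hunk)
    return _cached_op


print(get_op())  # -> "legacy-op", because the stubbed version is 1.0.3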
@@ -281,16 +284,17 @@ class AGLinear(torch.autograd.Function):
         if bw_gemm_rs_op is None:
             input_hidden_size = weight.size(-1)
-            bw_gemm_rs_op = flux.GemmRS(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size,
-                input_hidden_size,
-                input.dtype,
-                input.dtype,
-                transpose_weight=transpose_weight,
-                fuse_reduction=False
-            )
+            if not is_flux_min_version("1.1.0"):
+                bw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size,
+                    input_hidden_size,
+                    input.dtype,
+                    input.dtype,
+                    transpose_weight=transpose_weight,
+                    fuse_reduction=False
+                )

         grad_input = bw_gemm_rs_op.forward(
             grad_output.view(sequence_len * batch_size, -1),
@@ -523,16 +527,17 @@ class LinearRS(torch.autograd.Function):
         if sequence_parallel:
             if fw_gemm_rs_op is None:
-                fw_gemm_rs_op = flux.GemmRS(
-                    get_tensor_model_parallel_group(),
-                    get_tensor_model_parallel_node_size(),
-                    sequence_len * batch_size,
-                    output_hidden_size,
-                    input.dtype,
-                    input.dtype,
-                    transpose_weight=transpose_weight,
-                    fuse_reduction=False,
-                )
+                if not is_flux_min_version("1.1.0"):
+                    fw_gemm_rs_op = flux.GemmRS(
+                        get_tensor_model_parallel_group(),
+                        get_tensor_model_parallel_node_size(),
+                        sequence_len * batch_size,
+                        output_hidden_size,
+                        input.dtype,
+                        input.dtype,
+                        transpose_weight=transpose_weight,
+                        fuse_reduction=False,
+                    )

             output = fw_gemm_rs_op.forward(
                 input.view(sequence_len * batch_size, -1),
@@ -592,18 +597,19 @@ class LinearRS(torch.autograd.Function):
         input_hidden_size = weight.size(-1)
         if bw_ag_gemm_op is None:
-            bw_ag_gemm_op = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size * world_size,
-                input_hidden_size,
-                output_hidden_size,
-                grad_output.dtype,
-                output_dtype=input.dtype,
-                transpose_weight=transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
+            if not is_flux_min_version("1.1.0"):
+                bw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size * world_size,
+                    input_hidden_size,
+                    output_hidden_size,
+                    grad_output.dtype,
+                    output_dtype=input.dtype,
+                    transpose_weight=transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )

         grad_input = bw_ag_gemm_op.forward(
             grad_output.view(sequence_len * batch_size, -1),
             weight if transpose_weight else weight.t().contiguous(),
@@ -961,29 +967,30 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             self.fw_ag_gemm_op is None
             or current_flux_params != self.previous_flux_params
         ):
-            self.fw_ag_gemm_op = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size * world_size,
-                output_hidden_size,
-                input_hidden_size,
-                input_parallel.dtype,
-                output_dtype=input_parallel.dtype,
-                transpose_weight=self.flux_transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
+            if not is_flux_min_version("1.1.0"):
+                self.fw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size * world_size,
+                    output_hidden_size,
+                    input_hidden_size,
+                    input_parallel.dtype,
+                    output_dtype=input_parallel.dtype,
+                    transpose_weight=self.flux_transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )

             self.bw_gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
                 get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size * world_size,
                 input_hidden_size,
                 input_parallel.dtype,
                 input_parallel.dtype,
                 transpose_weight=self.flux_transpose_weight,
                 fuse_reduction=False
             )

             self.previous_flux_params = current_flux_params
@@ -1143,29 +1150,30 @@ class FluxRowParallelLinear(RowParallelLinear):
             self.fw_gemm_rs_op is None
             or current_flux_params != self.previous_flux_params
         ):
-            self.fw_gemm_rs_op = flux.GemmRS(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size,
-                output_hidden_size,
-                input_parallel.dtype,
-                input_parallel.dtype,
-                transpose_weight=self.flux_transpose_weight,
-                fuse_reduction=False
-            )
+            if not is_flux_min_version("1.1.0"):
+                self.fw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size,
+                    output_hidden_size,
+                    input_parallel.dtype,
+                    input_parallel.dtype,
+                    transpose_weight=self.flux_transpose_weight,
+                    fuse_reduction=False
+                )

             self.bw_ag_gemm_op = flux.AGKernel(
                 get_tensor_model_parallel_group(),
                 get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size,
                 input_hidden_size,
                 output_hidden_size,
                 input_parallel.dtype,
                 output_dtype=input_parallel.dtype,
                 transpose_weight=self.flux_transpose_weight,
                 local_copy=False,
                 ring_mode=flux.AgRingMode.Auto,
             )

             self.previous_flux_params = current_flux_params
...
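Note: in the two Flux*ParallelLinear hunks above, the ops are cached on the module and rebuilt whenever `current_flux_params` differs from `self.previous_flux_params`. A small, self-contained sketch of that rebuild-on-parameter-change caching follows; the class name, parameter tuple, and string "op" are hypothetical stand-ins, only the cache structure mirrors the diff.

# Sketch only: rebuild the cached op when it is missing or when the parameters it was specialised for change.
from typing import Optional, Tuple


class CachedOpHolder:
    def __init__(self) -> None:
        self.op: Optional[str] = None
        self.previous_params: Optional[Tuple] = None

    def get_op(self, seq_len: int, batch: int, hidden: int, dtype: str) -> str:
        current_params = (seq_len, batch, hidden, dtype)
        if self.op is None or current_params != self.previous_params:
            # Rebuild whenever any shape/dtype the kernel was specialised for changes.
            self.op = f"op[{seq_len}x{batch}x{hidden}:{dtype}]"
            self.previous_params = current_params
        return self.op


holder = CachedOpHolder()
print(holder.get_op(2048, 4, 4096, "bf16"))  # builds the op
print(holder.get_op(2048, 4, 4096, "bf16"))  # reuses the cached op
print(holder.get_op(1024, 4, 4096, "bf16"))  # rebuilds (sequence length changed)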
 import torch
 from typing import List, Optional, Union
+from importlib.metadata import version
+from packaging.version import Version as PkgVersion
+
+_flux_version = None
+
+
+def get_flux_version():
+    """Get flux version from __version__; if not available use pip's. Use caching."""
+
+    def get_flux_version_str():
+        import flux
+
+        if hasattr(flux, '__version__'):
+            return str(flux.__version__)
+        else:
+            return version("flux")
+
+    global _flux_version
+    if _flux_version is None:
+        _flux_version = PkgVersion(get_flux_version_str())
+    return _flux_version
+
+
+def is_flux_min_version(version, check_equality=True):
+    """Check if minimum version of `flux` is installed."""
+    if check_equality:
+        return get_flux_version() >= PkgVersion(version)
+    return get_flux_version() > PkgVersion(version)
+
+
 def tensor_slide(
...
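Note: `is_flux_min_version` relies on `packaging.version.Version`, so the comparison follows PEP 440 ordering rather than string comparison. The short illustration below shows the semantics of the `check_equality=True` (>=) and `check_equality=False` (>) paths; the version strings are examples only and do not refer to any particular flux release.

# Illustration of the PEP 440 comparison semantics behind is_flux_min_version.
from packaging.version import Version as PkgVersion

print(PkgVersion("1.1.0") >= PkgVersion("1.1.0"))     # True  -> check_equality=True path
print(PkgVersion("1.1.0") > PkgVersion("1.1.0"))      # False -> check_equality=False path
print(PkgVersion("1.1.0rc1") >= PkgVersion("1.1.0"))  # False -> pre-releases sort below the final release
print(PkgVersion("1.2.0") >= PkgVersion("1.1.0"))     # True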