Commit 4754c307 authored by dongcl

check flux version

parent b3371c31
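Apart from the new import, every hunk below makes the same change: the flux kernel objects (flux.AGKernel, flux.GemmRS) are now constructed only when the installed flux predates 1.1.0, gated by the new is_flux_min_version helper. A minimal, self-contained sketch of that gating idiom, using packaging directly with made-up version strings rather than the real flux lookup:

from packaging.version import Version

def needs_legacy_kernel(installed_version: str, minimum: str = "1.1.0") -> bool:
    # Mirrors `not is_flux_min_version(minimum)`: True when the installed
    # release is older than the minimum, i.e. the legacy kernel-construction
    # path in the hunks below should run.
    return Version(installed_version) < Version(minimum)

assert needs_legacy_kernel("1.0.3") is True
assert needs_legacy_kernel("1.1.0") is False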
@@ -50,6 +50,8 @@ from megatron.core.tensor_parallel.layers import (
     linear_with_frozen_weight,
     linear_with_grad_accumulation_and_async_allreduce
 )
+from dcu_megatron.core.utils import is_flux_min_version
 _grad_accum_fusion_available = True
 try:
@@ -208,18 +210,19 @@ class AGLinear(torch.autograd.Function):
         world_size = get_tensor_model_parallel_world_size()
         if fw_ag_gemm_op is None:
-            fw_ag_gemm_op = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size * world_size,
-                output_hidden_size,
-                input_hidden_size,
-                input.dtype,
-                output_dtype=input.dtype,
-                transpose_weight=transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
+            if not is_flux_min_version("1.1.0"):
+                fw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size * world_size,
+                    output_hidden_size,
+                    input_hidden_size,
+                    input.dtype,
+                    output_dtype=input.dtype,
+                    transpose_weight=transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )
         output = fw_ag_gemm_op.forward(
             input.view(sequence_len * batch_size, -1),
@@ -281,16 +284,17 @@ class AGLinear(torch.autograd.Function):
         if bw_gemm_rs_op is None:
             input_hidden_size = weight.size(-1)
-            bw_gemm_rs_op = flux.GemmRS(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size,
-                input_hidden_size,
-                input.dtype,
-                input.dtype,
-                transpose_weight=transpose_weight,
-                fuse_reduction=False
-            )
+            if not is_flux_min_version("1.1.0"):
+                bw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size,
+                    input_hidden_size,
+                    input.dtype,
+                    input.dtype,
+                    transpose_weight=transpose_weight,
+                    fuse_reduction=False
+                )
         grad_input = bw_gemm_rs_op.forward(
             grad_output.view(sequence_len * batch_size, -1),
@@ -523,16 +527,17 @@ class LinearRS(torch.autograd.Function):
         if sequence_parallel:
             if fw_gemm_rs_op is None:
-                fw_gemm_rs_op = flux.GemmRS(
-                    get_tensor_model_parallel_group(),
-                    get_tensor_model_parallel_node_size(),
-                    sequence_len * batch_size,
-                    output_hidden_size,
-                    input.dtype,
-                    input.dtype,
-                    transpose_weight=transpose_weight,
-                    fuse_reduction=False,
-                )
+                if not is_flux_min_version("1.1.0"):
+                    fw_gemm_rs_op = flux.GemmRS(
+                        get_tensor_model_parallel_group(),
+                        get_tensor_model_parallel_node_size(),
+                        sequence_len * batch_size,
+                        output_hidden_size,
+                        input.dtype,
+                        input.dtype,
+                        transpose_weight=transpose_weight,
+                        fuse_reduction=False,
+                    )
             output = fw_gemm_rs_op.forward(
                 input.view(sequence_len * batch_size, -1),
@@ -592,18 +597,19 @@ class LinearRS(torch.autograd.Function):
         input_hidden_size = weight.size(-1)
         if bw_ag_gemm_op is None:
-            bw_ag_gemm_op = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size * world_size,
-                input_hidden_size,
-                output_hidden_size,
-                grad_output.dtype,
-                output_dtype=input.dtype,
-                transpose_weight=transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
+            if not is_flux_min_version("1.1.0"):
+                bw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size * world_size,
+                    input_hidden_size,
+                    output_hidden_size,
+                    grad_output.dtype,
+                    output_dtype=input.dtype,
+                    transpose_weight=transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )
         grad_input = bw_ag_gemm_op.forward(
             grad_output.view(sequence_len * batch_size, -1),
             weight if transpose_weight else weight.t().contiguous(),
@@ -961,29 +967,30 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             self.fw_ag_gemm_op is None
             or current_flux_params != self.previous_flux_params
         ):
-            self.fw_ag_gemm_op = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size * world_size,
-                output_hidden_size,
-                input_hidden_size,
-                input_parallel.dtype,
-                output_dtype=input_parallel.dtype,
-                transpose_weight=self.flux_transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
-            self.bw_gemm_rs_op = flux.GemmRS(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size * world_size,
-                input_hidden_size,
-                input_parallel.dtype,
-                input_parallel.dtype,
-                transpose_weight=self.flux_transpose_weight,
-                fuse_reduction=False
-            )
+            if not is_flux_min_version("1.1.0"):
+                self.fw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size * world_size,
+                    output_hidden_size,
+                    input_hidden_size,
+                    input_parallel.dtype,
+                    output_dtype=input_parallel.dtype,
+                    transpose_weight=self.flux_transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )
+                self.bw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size * world_size,
+                    input_hidden_size,
+                    input_parallel.dtype,
+                    input_parallel.dtype,
+                    transpose_weight=self.flux_transpose_weight,
+                    fuse_reduction=False
+                )
         self.previous_flux_params = current_flux_params
@@ -1143,29 +1150,30 @@ class FluxRowParallelLinear(RowParallelLinear):
             self.fw_gemm_rs_op is None
             or current_flux_params != self.previous_flux_params
         ):
-            self.fw_gemm_rs_op = flux.GemmRS(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size,
-                output_hidden_size,
-                input_parallel.dtype,
-                input_parallel.dtype,
-                transpose_weight=self.flux_transpose_weight,
-                fuse_reduction=False
-            )
-            self.bw_ag_gemm_op = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_node_size(),
-                sequence_len * batch_size,
-                input_hidden_size,
-                output_hidden_size,
-                input_parallel.dtype,
-                output_dtype=input_parallel.dtype,
-                transpose_weight=self.flux_transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
+            if not is_flux_min_version("1.1.0"):
+                self.fw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size,
+                    output_hidden_size,
+                    input_parallel.dtype,
+                    input_parallel.dtype,
+                    transpose_weight=self.flux_transpose_weight,
+                    fuse_reduction=False
+                )
+                self.bw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size,
+                    input_hidden_size,
+                    output_hidden_size,
+                    input_parallel.dtype,
+                    output_dtype=input_parallel.dtype,
+                    transpose_weight=self.flux_transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )
         self.previous_flux_params = current_flux_params
......
import torch
from typing import List, Optional, Union
from importlib.metadata import version
from packaging.version import Version as PkgVersion

_flux_version = None

def get_flux_version():
    """Get flux version from __version__; if not available use pip's. Use caching."""

    def get_flux_version_str():
        import flux
        if hasattr(flux, '__version__'):
            return str(flux.__version__)
        else:
            return version("flux")

    global _flux_version
    if _flux_version is None:
        _flux_version = PkgVersion(get_flux_version_str())
    return _flux_version

def is_flux_min_version(version, check_equality=True):
    """Check if minimum version of `flux` is installed."""
    if check_equality:
        return get_flux_version() >= PkgVersion(version)
    return get_flux_version() > PkgVersion(version)
def tensor_slide(
......
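For reference, an illustrative usage sketch of the version helpers added above (it assumes flux is importable, or installed under the distribution name "flux" so importlib.metadata can resolve it; the version strings are examples only):

from dcu_megatron.core.utils import get_flux_version, is_flux_min_version

flux_version = get_flux_version()   # packaging Version object, cached after the first lookup
print(f"detected flux {flux_version}")

if not is_flux_min_version("1.1.0"):
    # Older flux: keep building flux.AGKernel / flux.GemmRS by hand, as in the diff above.
    pass

# check_equality=False makes the comparison strict (flux must be newer than 1.1.0, not equal to it).
strictly_newer = is_flux_min_version("1.1.0", check_equality=False)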