Commit 4754c307 authored by dongcl

check flux version

parent b3371c31
@@ -50,6 +50,8 @@ from megatron.core.tensor_parallel.layers import (
     linear_with_frozen_weight,
     linear_with_grad_accumulation_and_async_allreduce
 )
+from dcu_megatron.core.utils import is_flux_min_version
 
 _grad_accum_fusion_available = True
 try:
@@ -208,6 +210,7 @@ class AGLinear(torch.autograd.Function):
         world_size = get_tensor_model_parallel_world_size()
 
         if fw_ag_gemm_op is None:
+            if not is_flux_min_version("1.1.0"):
                 fw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
                     get_tensor_model_parallel_node_size(),
@@ -281,6 +284,7 @@ class AGLinear(torch.autograd.Function):
         if bw_gemm_rs_op is None:
             input_hidden_size = weight.size(-1)
+            if not is_flux_min_version("1.1.0"):
                 bw_gemm_rs_op = flux.GemmRS(
                     get_tensor_model_parallel_group(),
                     get_tensor_model_parallel_node_size(),
@@ -523,6 +527,7 @@ class LinearRS(torch.autograd.Function):
         if sequence_parallel:
             if fw_gemm_rs_op is None:
+                if not is_flux_min_version("1.1.0"):
                     fw_gemm_rs_op = flux.GemmRS(
                         get_tensor_model_parallel_group(),
                         get_tensor_model_parallel_node_size(),
@@ -592,6 +597,7 @@ class LinearRS(torch.autograd.Function):
         input_hidden_size = weight.size(-1)
 
         if bw_ag_gemm_op is None:
+            if not is_flux_min_version("1.1.0"):
                 bw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
                     get_tensor_model_parallel_node_size(),
@@ -961,6 +967,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             self.fw_ag_gemm_op is None
             or current_flux_params != self.previous_flux_params
         ):
+            if not is_flux_min_version("1.1.0"):
                 self.fw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
                     get_tensor_model_parallel_node_size(),
@@ -1143,6 +1150,7 @@ class FluxRowParallelLinear(RowParallelLinear):
             self.fw_gemm_rs_op is None
             or current_flux_params != self.previous_flux_params
         ):
+            if not is_flux_min_version("1.1.0"):
                 self.fw_gemm_rs_op = flux.GemmRS(
                     get_tensor_model_parallel_group(),
                     get_tensor_model_parallel_node_size(),
...
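Every hunk above applies the same change: the call that builds a flux kernel (flux.AGKernel or flux.GemmRS) is placed behind a check that the installed flux is older than 1.1.0, so the existing constructor arguments are only used on pre-1.1.0 builds. A minimal sketch of that dispatch shape, with a hypothetical make_kernel helper standing in for the real construction sites; the branch for flux >= 1.1.0 is not part of the visible hunks and is left as a placeholder here:

from dcu_megatron.core.utils import is_flux_min_version

def make_kernel(legacy_ctor, new_ctor):
    # Hypothetical helper, for illustration only: choose the constructor
    # path based on the installed flux version.
    if not is_flux_min_version("1.1.0"):
        return legacy_ctor()   # pre-1.1.0 signature, as in the hunks above
    return new_ctor()          # whatever flux >= 1.1.0 expects (not shown in this commit)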
 import torch
 from typing import List, Optional, Union
+from importlib.metadata import version
+from packaging.version import Version as PkgVersion
 
+_flux_version = None
+
+
+def get_flux_version():
+    """Get flux version from __version__; if not available use pip's. Use caching."""
+
+    def get_flux_version_str():
+        import flux
+
+        if hasattr(flux, '__version__'):
+            return str(flux.__version__)
+        else:
+            return version("flux")
+
+    global _flux_version
+    if _flux_version is None:
+        _flux_version = PkgVersion(get_flux_version_str())
+    return _flux_version
+
+
+def is_flux_min_version(version, check_equality=True):
+    """Check if minimum version of `flux` is installed."""
+    if check_equality:
+        return get_flux_version() >= PkgVersion(version)
+    return get_flux_version() > PkgVersion(version)
+
+
 def tensor_slide(
...
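The new helpers closely follow the pattern Megatron-Core uses for its Transformer Engine version checks: get_flux_version() prefers flux.__version__, falls back to the pip package metadata, and caches the parsed PkgVersion; is_flux_min_version() then compares inclusively (>=) by default, or strictly (>) when check_equality=False. A short usage sketch, assuming flux is installed; the version numbers in the comments are illustrative, not taken from the commit:

from dcu_megatron.core.utils import get_flux_version, is_flux_min_version

print(get_flux_version())                                   # parsed once, then cached, e.g. 1.0.3
print(is_flux_min_version("1.1.0"))                         # False on a 1.0.x install (inclusive >=)
print(is_flux_min_version("1.0.3", check_equality=False))   # strict >, False on exactly 1.0.3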