Commit 063ef88d authored by wenjh's avatar wenjh
Browse files

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -79,6 +79,8 @@ using fp8e8m0 = uint8_t;
using int8 = int8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
template <typename T>
......@@ -240,7 +242,9 @@ class Tensor {
float scale() const {
if(scale_cpu_data_) {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!");
NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING),
"Invalid scaling_mode!");
to_cpu();
return *scale_cpu_data_;
} else {
......@@ -254,6 +258,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -267,6 +273,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -321,10 +329,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// [128,4] rowwise and [4,128] colwise alignment requirement
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
inline size_t divide_round_up(const size_t N, const size_t M) {
return (N - 1 + M) / M;
......@@ -473,12 +481,14 @@ void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num,
const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0);
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num,
const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols);
......@@ -501,6 +511,7 @@ const std::string& caseName(InputsFillCase type);
extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);
bool isFp4Type(DType type);
int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
......@@ -578,7 +589,7 @@ constexpr int32_t blackwellComputeCapability = 100;
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
printf("dtype: %d\n", static_cast<int>(dtype)); \
NVTE_ERROR("Invalid type MARKED TEST."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
......@@ -597,7 +608,7 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type MARKED TEST 2."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
......@@ -605,7 +616,7 @@ constexpr int32_t blackwellComputeCapability = 100;
using namespace transformer_engine; \
SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type MARKED TEST 3."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
......@@ -630,5 +641,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type MARKED TEST 4."); \
NVTE_ERROR("Invalid type."); \
}
......@@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) {
return supported;
}
int GetDeviceComputeCapability(int device_id) {
int major{};
int minor{};
CHECK_CU(cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device_id));
CHECK_CU(cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device_id));
return major * 10 + minor;
}
template <typename T>
bool IsDTypeSupported(int /* device_id */) {
return true;
}
template <>
bool IsDTypeSupported<test::fp8e5m2>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
template <>
bool IsDTypeSupported<test::fp8e4m3>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
template <typename... Ts>
bool AllDTypesSupported(int device_id) {
return (IsDTypeSupported<Ts>(device_id) && ...);
}
template <typename T>
std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld) {
......@@ -161,6 +189,9 @@ class CommGemmFixure : public ::testing::TestWithParam<Params> {
template <typename AType, typename BType, typename DType, typename BiasType>
void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) {
if (!AllDTypesSupported<AType, BType, DType, BiasType>(rank_))
GTEST_SKIP() << "FP8 is not supported on device " << rank_;
cudaStream_t stream{};
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
......
......@@ -17,14 +17,6 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2")
)
if is_devices_enough(4):
configs.append(
pytest.param(
......@@ -32,10 +24,17 @@ def generate_configs():
(2, 2),
("dp", "tpsp"),
MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
id=f"n4_dp2_tp2",
id="n4_dp2_tp2",
)
)
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"),
)
return configs
......
This diff is collapsed.
This diff is collapsed.
......@@ -9,7 +9,7 @@ import numpy as np
from utils import pytest_parametrize_wrapper, is_devices_enough
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
def generate_mesh_configs():
......@@ -26,10 +26,10 @@ def generate_mesh_configs():
class TestMeshResource(unittest.TestCase):
def test_fp8_autocast_with_mesh_resource(self):
def test_autocast_with_mesh_resource(self):
for mesh_config in generate_mesh_configs():
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
self.assertEqual(mesh_resource, global_mesh_resource())
......@@ -15,7 +15,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
......@@ -66,20 +66,19 @@ class TestDistributedLayernorm:
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
# TODO(Phuong) is_dp_enabled = dp mesh axis size > 1
is_dp_enabled = mesh_resource.dp_resource is not None
is_tpsp_enabled = mesh_resource.tpsp_resource is not None
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
# loss, 1 FP32
allreduce_total_bytes = 4 if is_dp_enabled else 0
# dgamma and dbeta
weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled),
allgather=0,
other=0,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
......@@ -134,7 +133,7 @@ class TestDistributedLayernorm:
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
......@@ -210,7 +209,7 @@ class TestDistributedLayernorm:
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
......
This diff is collapsed.
......@@ -15,7 +15,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -102,7 +102,7 @@ class TestDistributedSoftmax:
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment