Commit 063ef88d authored by wenjh's avatar wenjh
Browse files

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
...@@ -79,6 +79,8 @@ using fp8e8m0 = uint8_t; ...@@ -79,6 +79,8 @@ using fp8e8m0 = uint8_t;
using int8 = int8_t; using int8 = int8_t;
#if FP4_TYPE_SUPPORTED #if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif #endif
template <typename T> template <typename T>
...@@ -240,7 +242,9 @@ class Tensor { ...@@ -240,7 +242,9 @@ class Tensor {
float scale() const { float scale() const {
if(scale_cpu_data_) { if(scale_cpu_data_) {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING),
"Invalid scaling_mode!");
to_cpu(); to_cpu();
return *scale_cpu_data_; return *scale_cpu_data_;
} else { } else {
...@@ -254,6 +258,8 @@ class Tensor { ...@@ -254,6 +258,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!"); NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!"); NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else { } else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!"); NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
} }
...@@ -267,6 +273,8 @@ class Tensor { ...@@ -267,6 +273,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!"); NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!"); NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else { } else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!"); NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
} }
...@@ -321,10 +329,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; ...@@ -321,10 +329,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_MANTISSA_BITS = 23;
// [128,4] rowwise and [4,128] colwise alignment requirement // [128,4] rowwise and [4,128] colwise alignment requirement
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
inline size_t divide_round_up(const size_t N, const size_t M) { inline size_t divide_round_up(const size_t N, const size_t M) {
return (N - 1 + M) / M; return (N - 1 + M) / M;
...@@ -473,13 +481,15 @@ void compareResults(const std::string &name, const float test, const float ref, ...@@ -473,13 +481,15 @@ void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8); double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.); size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride, const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, size_t& mismatches_num,
const size_t scale_diff_abs_tolerance = 0, const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0, const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0); const double rel_tolerable_mismatches_limit = 0);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols, std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols); const size_t block_size_rows, const size_t block_size_cols);
...@@ -501,6 +511,7 @@ const std::string& caseName(InputsFillCase type); ...@@ -501,6 +511,7 @@ const std::string& caseName(InputsFillCase type);
extern std::vector<DType> all_fp_types; extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type); bool isFp8Type(DType type);
bool isFp4Type(DType type);
int32_t getDeviceComputeCapability(); int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90; constexpr int32_t hopperComputeCapability = 90;
...@@ -578,7 +589,7 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -578,7 +589,7 @@ constexpr int32_t blackwellComputeCapability = 100;
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \ default: \
printf("dtype: %d\n", static_cast<int>(dtype)); \ printf("dtype: %d\n", static_cast<int>(dtype)); \
NVTE_ERROR("Invalid type MARKED TEST."); \ NVTE_ERROR("Invalid type."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
...@@ -597,7 +608,7 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -597,7 +608,7 @@ constexpr int32_t blackwellComputeCapability = 100;
} \ } \
break; \ break; \
default: \ default: \
NVTE_ERROR("Invalid type MARKED TEST 2."); \ NVTE_ERROR("Invalid type."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
...@@ -605,7 +616,7 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -605,7 +616,7 @@ constexpr int32_t blackwellComputeCapability = 100;
using namespace transformer_engine; \ using namespace transformer_engine; \
SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
default: \ default: \
NVTE_ERROR("Invalid type MARKED TEST 3."); \ NVTE_ERROR("Invalid type."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
...@@ -630,5 +641,5 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -630,5 +641,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \ } \
break; \ break; \
default: \ default: \
NVTE_ERROR("Invalid type MARKED TEST 4."); \ NVTE_ERROR("Invalid type."); \
} }
...@@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) { ...@@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) {
return supported; return supported;
} }
int GetDeviceComputeCapability(int device_id) {
int major{};
int minor{};
CHECK_CU(cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device_id));
CHECK_CU(cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device_id));
return major * 10 + minor;
}
template <typename T>
bool IsDTypeSupported(int /* device_id */) {
return true;
}
template <>
bool IsDTypeSupported<test::fp8e5m2>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
template <>
bool IsDTypeSupported<test::fp8e4m3>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
template <typename... Ts>
bool AllDTypesSupported(int device_id) {
return (IsDTypeSupported<Ts>(device_id) && ...);
}
template <typename T> template <typename T>
std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize, std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld) { size_t nsize, size_t ld) {
...@@ -161,6 +189,9 @@ class CommGemmFixure : public ::testing::TestWithParam<Params> { ...@@ -161,6 +189,9 @@ class CommGemmFixure : public ::testing::TestWithParam<Params> {
template <typename AType, typename BType, typename DType, typename BiasType> template <typename AType, typename BType, typename DType, typename BiasType>
void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) {
if (!AllDTypesSupported<AType, BType, DType, BiasType>(rank_))
GTEST_SKIP() << "FP8 is not supported on device " << rank_;
cudaStream_t stream{}; cudaStream_t stream{};
NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
......
...@@ -17,14 +17,6 @@ from utils import assert_allclose, is_devices_enough ...@@ -17,14 +17,6 @@ from utils import assert_allclose, is_devices_enough
def generate_configs(): def generate_configs():
configs = [] configs = []
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2")
)
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
pytest.param( pytest.param(
...@@ -32,10 +24,17 @@ def generate_configs(): ...@@ -32,10 +24,17 @@ def generate_configs():
(2, 2), (2, 2),
("dp", "tpsp"), ("dp", "tpsp"),
MeshResource(dp_resource="dp", tpsp_resource="tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
id=f"n4_dp2_tp2", id="n4_dp2_tp2",
) )
) )
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"),
)
return configs return configs
......
This diff is collapsed.
This diff is collapsed.
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from utils import pytest_parametrize_wrapper, is_devices_enough from utils import pytest_parametrize_wrapper, is_devices_enough
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax import fp8_autocast from transformer_engine.jax import autocast
def generate_mesh_configs(): def generate_mesh_configs():
...@@ -26,10 +26,10 @@ def generate_mesh_configs(): ...@@ -26,10 +26,10 @@ def generate_mesh_configs():
class TestMeshResource(unittest.TestCase): class TestMeshResource(unittest.TestCase):
def test_fp8_autocast_with_mesh_resource(self): def test_autocast_with_mesh_resource(self):
for mesh_config in generate_mesh_configs(): for mesh_config in generate_mesh_configs():
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_axes) mesh = jax.sharding.Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
self.assertEqual(mesh_resource, global_mesh_resource()) self.assertEqual(mesh_resource, global_mesh_resource())
...@@ -15,7 +15,7 @@ from distributed_test_base import generate_configs, generate_collectives_count ...@@ -15,7 +15,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper from utils import pytest_parametrize_wrapper
from transformer_engine.jax import fp8_autocast from transformer_engine.jax import autocast
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
...@@ -66,20 +66,19 @@ class TestDistributedLayernorm: ...@@ -66,20 +66,19 @@ class TestDistributedLayernorm:
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
): ):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype) jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
# TODO(Phuong) is_dp_enabled = dp mesh axis size > 1
is_dp_enabled = mesh_resource.dp_resource is not None is_dp_enabled = mesh_resource.dp_resource is not None
is_tpsp_enabled = mesh_resource.tpsp_resource is not None
assert ln_type in ["layernorm", "rmsnorm"] assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32 # loss, 1 FP32
# for loss, dgamma and dbeta allreduce_total_bytes = 4 if is_dp_enabled else 0
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp # dgamma and dbeta
weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1 weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = ( allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count( return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled),
allgather=0,
other=0,
) )
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
...@@ -134,7 +133,7 @@ class TestDistributedLayernorm: ...@@ -134,7 +133,7 @@ class TestDistributedLayernorm:
) )
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
...@@ -210,7 +209,7 @@ class TestDistributedLayernorm: ...@@ -210,7 +209,7 @@ class TestDistributedLayernorm:
) )
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
......
This diff is collapsed.
...@@ -15,7 +15,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec ...@@ -15,7 +15,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax from transformer_engine.jax.softmax import SoftmaxType, softmax
DTYPES = [jnp.float16, jnp.bfloat16] DTYPES = [jnp.float16, jnp.bfloat16]
...@@ -102,7 +102,7 @@ class TestDistributedSoftmax: ...@@ -102,7 +102,7 @@ class TestDistributedSoftmax:
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
......
This diff is collapsed.
This diff is collapsed.
...@@ -28,7 +28,7 @@ from transformer_engine.jax.quantize import ( ...@@ -28,7 +28,7 @@ from transformer_engine.jax.quantize import (
is_fp8_available, is_fp8_available,
update_collections, update_collections,
TensorSource, TensorSource,
fp8_autocast, autocast,
) )
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
...@@ -507,14 +507,14 @@ class BaseTester: ...@@ -507,14 +507,14 @@ class BaseTester:
"""Test normal datatype forward""" """Test normal datatype forward"""
# Ensure FP8 disabled. # Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device # Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()): with autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype) self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs): def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward""" """Test normal datatype backward"""
# Ensure FP8 disabled. # Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device # Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()): with autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype) self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -522,7 +522,7 @@ class BaseTester: ...@@ -522,7 +522,7 @@ class BaseTester:
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled""" """Test forward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device # Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -530,7 +530,7 @@ class BaseTester: ...@@ -530,7 +530,7 @@ class BaseTester:
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled""" """Test backward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device # Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
......
...@@ -1544,6 +1544,12 @@ def dtype_tols( ...@@ -1544,6 +1544,12 @@ def dtype_tols(
rtol = eps_relaxed rtol = eps_relaxed
if atol is None: if atol is None:
atol = max(ulp, eps_relaxed) atol = max(ulp, eps_relaxed)
# Manually set tols for nvfp4
if dtype == jnp.float4_e2m1fn:
atol = 0.05
rtol = 0.1
return {"rtol": rtol, "atol": atol} return {"rtol": rtol, "atol": atol}
......
This diff is collapsed.
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
"""Unit tests for context parallel utils.""" """Unit tests for context parallel utils."""
import torch import torch
import unittest import unittest
from typing import Tuple
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_batch_on_this_cp_rank, get_batch_on_this_cp_rank,
pad_thd_sequences_for_cp, pad_thd_sequences_for_cp,
......
This diff is collapsed.
This diff is collapsed.
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
import torch import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment