Unverified Commit a5694f26 authored by Evgeny Tsykunov, committed by GitHub

Add separate RNG states for column-wise quantization with Stochastic Rounding (#2487)



* Add separate RNG states for columnwise quantization with Stochastic Rounding
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix single tensor path
Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
parent 93c5c65b
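The fix hinges on drawing two independent Philox states from the same CUDA generator: each call to init_philox_state reserves its own slice of the generator's counter space, so rowwise and columnwise stochastic rounding consume disjoint random streams instead of reusing one (seed, offset) pair. A minimal sketch of that pattern, assuming init_philox_state is a thin wrapper over at::CUDAGeneratorImpl::philox_cuda_state (the helper's actual body is not shown in this diff):

#include <mutex>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>

// Sketch only (assumed helper body, not copied from the patch):
// philox_cuda_state(increment) hands back the generator's current
// (seed, offset) pair and advances the offset by `increment`, so
// later callers receive fresh Philox counters.
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t increment) {
  std::lock_guard<std::mutex> lock(gen->mutex_);
  return gen->philox_cuda_state(increment);
}

// Two consecutive draws never overlap: if the first reserves counter
// offsets [o, o + 1024), the second starts at o + 1024. The patch relies
// on this by calling init_philox_state once per quantization direction.
void draw_two_independent_states() {
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  at::PhiloxCudaState rowwise_state = init_philox_state(gen, 1024);
  at::PhiloxCudaState columnwise_state = init_philox_state(gen, 1024);
  (void)rowwise_state;
  (void)columnwise_state;
}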
@@ -761,8 +761,16 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
  }
  // Stochastic rounding
  // When both rowwise and columnwise quantization are used,
  // we need separate RNG states for each to ensure they use different random numbers.
  std::vector<TensorWrapper> te_rng_state_list;
  std::vector<TensorWrapper> te_rng_state_columnwise_list;
  std::vector<QuantizationConfigWrapper> quant_config_columnwise_list;
  at::Tensor rng_states_tensor;
  at::Tensor rng_states_columnwise_tensor;
  const bool need_separate_columnwise_rng =
      quantizer.stochastic_rounding && quantizer.with_rht && quantizer.columnwise_usage;
  if (quantizer.stochastic_rounding) {
    // TODO(zhongbo): remove the for loop of generating rng states with a single call
    // with rng_elts_per_thread = 1024 * num_tensors
@@ -770,9 +778,18 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
    const size_t rng_elts_per_thread = 1024;  // Wild guess, probably can be tightened
    auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
    rng_states_tensor = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
    // Allocate columnwise RNG resources when separate RNG is needed
    if (need_separate_columnwise_rng) {
      rng_states_columnwise_tensor = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
      for (size_t i = 0; i < num_tensors; ++i) {
        quant_config_columnwise_list.emplace_back(QuantizationConfigWrapper());
      }
    }
    for (size_t i = 0; i < num_tensors; ++i) {
      auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
          std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
      // Generate RNG state for rowwise quantization
      at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
      int64_t *rng_state_ptr = static_cast<int64_t *>(rng_states_tensor.data_ptr()) + i * 2;
      philox_unpack(philox_args, rng_state_ptr);
@@ -780,6 +797,18 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
          static_cast<void *>(rng_state_ptr), std::vector<size_t>{2}, DType::kInt64));
      quant_config_list[i].set_rng_state(te_rng_state_list[i].data());
      quant_config_list[i].set_stochastic_rounding(true);
      // Generate separate RNG state for columnwise quantization
      if (need_separate_columnwise_rng) {
        at::PhiloxCudaState philox_args_columnwise = init_philox_state(gen, rng_elts_per_thread);
        int64_t *rng_state_columnwise_ptr =
            static_cast<int64_t *>(rng_states_columnwise_tensor.data_ptr()) + i * 2;
        philox_unpack(philox_args_columnwise, rng_state_columnwise_ptr);
        te_rng_state_columnwise_list.push_back(makeTransformerEngineTensor(
            static_cast<void *>(rng_state_columnwise_ptr), std::vector<size_t>{2}, DType::kInt64));
        quant_config_columnwise_list[i].set_rng_state(te_rng_state_columnwise_list[i].data());
        quant_config_columnwise_list[i].set_stochastic_rounding(true);
      }
    }
  }
@@ -864,9 +893,12 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
                                                       out_columnwise_amax.shape);
      // RHT + NVFP4 quantize kernel
      // Use separate RNG state for columnwise to ensure different random numbers than rowwise
      auto &columnwise_quant_config =
          need_separate_columnwise_rng ? quant_config_columnwise_list[i] : quant_config_list[i];
      nvte_hadamard_transform_cast_fusion_columnwise(input_list[i].data(), out_transpose.data(),
                                                     rht_matrix_nvte.data(),
                                                     columnwise_quant_config, stream);
    }
  }
  });
...
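Both RNG-state allocations in the multi-tensor path above share one flat layout: a single CUDA int64 tensor of shape {2 * num_tensors}, in which tensor i owns the two-element Philox state at indices 2*i and 2*i + 1. A hedged illustration of that indexing (hypothetical helper, not part of the patch):

#include <ATen/ATen.h>

// Hypothetical helper (illustration only): returns tensor i's two-element
// Philox state inside the shared {2 * num_tensors} allocation built above.
// Presumably element 2*i holds the seed and 2*i + 1 the offset, matching
// what philox_unpack writes through rng_state_ptr in the loop above.
int64_t *rng_state_for_tensor(at::Tensor &rng_states, size_t i) {
  return static_cast<int64_t *>(rng_states.data_ptr()) + i * 2;
}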
@@ -1468,17 +1468,37 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
  }
  size_t cols = input.size(input.ndim() - 1);
  // Stochastic rounding
  // When both rowwise and columnwise quantization are used with RHT,
  // we need separate RNG states for each to ensure they use different random numbers.
  TensorWrapper te_rng_state;
  TensorWrapper te_rng_state_columnwise;
  QuantizationConfigWrapper quant_config_columnwise;
  const bool need_separate_columnwise_rng =
      this->stochastic_rounding && this->with_rht && this->columnwise_usage;
  if (this->stochastic_rounding) {
    const size_t rng_elts_per_thread = 1024;  // Wild guess, probably can be tightened
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
    auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
    // Generate RNG state for rowwise quantization
    at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
    auto rng_state = torch::empty({2}, opts);
    philox_unpack(philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
    te_rng_state = makeTransformerEngineTensor(rng_state);
    quant_config.set_rng_state(te_rng_state.data());
    // Generate separate RNG state for columnwise quantization
    if (need_separate_columnwise_rng) {
      at::PhiloxCudaState philox_args_columnwise = init_philox_state(gen, rng_elts_per_thread);
      auto rng_state_columnwise = torch::empty({2}, opts);
      philox_unpack(philox_args_columnwise, static_cast<int64_t*>(rng_state_columnwise.data_ptr()));
      te_rng_state_columnwise = makeTransformerEngineTensor(rng_state_columnwise);
      quant_config_columnwise.set_stochastic_rounding(true);
      quant_config_columnwise.set_rng_state(te_rng_state_columnwise.data());
    }
  }
  // Restriction for the RHT cast fusion kernel.
@@ -1605,6 +1625,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
                                                     static_cast<DType>(out_columnwise_amax.dtype),
                                                     out_columnwise_amax.shape);
    // Use separate RNG state for columnwise to ensure different random numbers than rowwise
    auto& columnwise_quant_config =
        need_separate_columnwise_rng ? quant_config_columnwise : quant_config;
    if (!eligible_for_rht_cast_fusion) {
      // Invoking fallback RHT kernel.
@@ -1629,7 +1653,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
      // Quantize kernel will treat everything as rowwise input/output, which is
      // intended.
      NVTE_SCOPED_GIL_RELEASE({
        nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), columnwise_quant_config,
                         stream);
      });
    } else {
      // RHT cast fusion kernel.
@@ -1637,8 +1662,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
                 "RHT matrix is not set");
      auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
      NVTE_SCOPED_GIL_RELEASE({
        nvte_hadamard_transform_cast_fusion_columnwise(input.data(), out_transpose.data(),
                                                       rht_matrix_nvte.data(),
                                                       columnwise_quant_config, stream);
      });
    }
  }
...
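A possible sanity check for the single-tensor path above (hypothetical test code, not part of the patch): draw the two states the way quantize_impl now does and assert the offsets differ, since an identical (seed, offset) pair was exactly the bug being fixed. This assumes eager (non-graph-capture) mode, where at::PhiloxCudaState stores the concrete offset in offset_.val, and reuses init_philox_state as sketched earlier.

#include <ATen/Utils.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>

// Hypothetical sanity check: consecutive Philox draws must not alias,
// otherwise rowwise and columnwise stochastic rounding would consume the
// same random numbers. Assumes captured_ == false, so the offset lives in
// offset_.val rather than behind a device pointer.
void check_rng_states_differ() {
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  at::PhiloxCudaState rowwise = init_philox_state(gen, 1024);     // as in quantize_impl
  at::PhiloxCudaState columnwise = init_philox_state(gen, 1024);  // separate columnwise draw
  TORCH_CHECK(rowwise.offset_.val != columnwise.offset_.val,
              "rowwise and columnwise RNG states must use disjoint offsets");
}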