Unverified Commit 62d1b2bd authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

Introduce nvte_memset to provide a fill kernel that is faster than cudaMemsetAsync for small sizes


Introduce nvte_memset to provide a fill kernel that is faster than cudaMemsetAsync for small sizes (#1716)

* nvte_memset fills single float value
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Support larger sizes than a single value and add tests
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 7186df4f
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
// Parameterized test fixture for nvte_memset.
// Params: (fill value, buffer size in bytes).
class MemsetTestSuite : public ::testing::TestWithParam<std::tuple<int,
size_t>> {};
// Verifies that nvte_memset fills every byte of a device buffer with the
// requested value, across a range of sizes and both dispatch paths (custom
// fill kernel for small sizes, cudaMemsetAsync above the 4096-byte threshold).
TEST_P(MemsetTestSuite, TestMemset) {
  using namespace transformer_engine;
  using namespace test;
  const int value = std::get<0>(GetParam());
  const size_t size_in_bytes = std::get<1>(GetParam());

  // Seed the host buffer with a value different from the memset value so a
  // no-op memset cannot pass the test by accident.
  std::vector<uint8_t> h_buffer(size_in_bytes, static_cast<uint8_t>(value + 1));

  char* d_ptr;
  NVTE_CHECK_CUDA(cudaMalloc(&d_ptr, size_in_bytes));
  NVTE_CHECK_CUDA(cudaMemcpy(d_ptr, h_buffer.data(), size_in_bytes, cudaMemcpyHostToDevice));

  nvte_memset(d_ptr, value, size_in_bytes, 0 /* stream */);

  // cudaMemcpy on the default stream is ordered after the memset and blocks
  // until it completes. Synchronize BEFORE freeing the device buffer so any
  // asynchronous kernel error is surfaced while the allocation is still live
  // (the original test synchronized after cudaFree, which reports errors too
  // late to attribute them to the memset).
  NVTE_CHECK_CUDA(cudaMemcpy(
      h_buffer.data(), d_ptr, size_in_bytes, cudaMemcpyDeviceToHost));
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
  NVTE_CHECK_CUDA(cudaFree(d_ptr));

  for (size_t i = 0; i < size_in_bytes; ++i) {
    EXPECT_EQ(h_buffer[i], static_cast<uint8_t>(value))
        << "Mismatch at index " << i << ": expected " << static_cast<int>(value)
        << ", got " << static_cast<int>(h_buffer[i]);
  }
}
namespace {
// Buffer sizes exercised by the tests. Values deliberately straddle vector
// widths (1, 4, 9, 16) and the 4096-byte threshold above which nvte_memset
// delegates to cudaMemsetAsync (4096, 4097, 8192).
std::vector<size_t> memset_test_sizes = {
1,
4,
9,
16,
128,
4096,
4097,
8192,
};
} // namespace
// Instantiate the suite over {0, 6} fill values crossed with every test size.
// The name generator yields e.g. "6X4096" (value "X" size) for each case.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, MemsetTestSuite,
    ::testing::Combine(::testing::Values(0, 6),
                       ::testing::ValuesIn(memset_test_sizes)),
    [](const testing::TestParamInfo<MemsetTestSuite::ParamType>& param_info) {
      const int fill_value = std::get<0>(param_info.param);
      const size_t num_bytes = std::get<1>(param_info.param);
      return std::to_string(fill_value) + "X" + std::to_string(num_bytes);
    });
......@@ -35,6 +35,65 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
}
}
namespace {

constexpr size_t kThreadsPerBlock = 256;

// Fills `size_in_bytes` bytes at `ptr` with the low byte of `value`, one
// TVectorized element per thread. The launcher must supply enough threads to
// cover the whole buffer (see MEMSET_VECTORIZED_KERNEL_DISPATCH); the single
// thread whose element straddles the end of the buffer writes its tail
// byte-wise instead of with a vectorized store.
template <typename TVectorized>
__global__ void __launch_bounds__(kThreadsPerBlock)
    memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t byte_offset = tid * sizeof(TVectorized);
  if (byte_offset >= size_in_bytes) {
    return;  // This thread's element lies entirely past the end of the buffer.
  }
  if (byte_offset + sizeof(TVectorized) > size_in_bytes) {
    // Partial element at the tail: the buffer size is not a multiple of the
    // vector width, so fill the leftover bytes without vectorization.
    memset(reinterpret_cast<uint8_t *>(ptr) + byte_offset, value,
           size_in_bytes - byte_offset);
    return;
  }
  // Build one vector element in which every byte equals the fill value, then
  // commit it with a single vectorized store.
  union {
    TVectorized vec;
    uint8_t bytes[sizeof(TVectorized)];
  } fill;
  for (size_t i = 0; i < sizeof(TVectorized); ++i) {
    fill.bytes[i] = static_cast<uint8_t>(value);
  }
  reinterpret_cast<TVectorized *>(ptr)[tid] = fill.vec;
}

}  // namespace
// Launches memset_kernel<vectorizedType> when the buffer holds at least one
// vectorizedType element and the pointer is aligned to that type; otherwise
// falls through so the caller can try the next narrower vector width.
// Expands to a `return;` on dispatch, so it may only be used inside a
// function returning void (nvte_memset). Arguments are parenthesized at each
// use so expression arguments expand correctly.
#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
  if ((size_in_bytes) >= sizeof(vectorizedType) &&                                           \
      reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) {                         \
    size_t numBlocks = DIVUP((size_in_bytes), kThreadsPerBlock * sizeof(vectorizedType));    \
    dim3 grid(numBlocks, 1, 1);                                                              \
    memset_kernel<vectorizedType>                                                            \
        <<<grid, kThreadsPerBlock, 0, (stream)>>>((ptr), (value), (size_in_bytes));          \
    return;                                                                                  \
  }
extern "C" {

/* Sets size_in_bytes bytes at ptr (device memory) to the low byte of value.
 * Uses a custom fill kernel for small sizes and cudaMemsetAsync above the
 * 4096-byte threshold. Asynchronous with respect to the host; ordered on
 * `stream`. */
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
  NVTE_API_CALL(nvte_memset);
  NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");
  if (size_in_bytes == 0) {
    return;  // Nothing to fill.
  }
  if (size_in_bytes > 4096) {
    // cudaMemsetAsync outperforms the fill kernel for larger sizes. The
    // original code discarded this call's status; check it so failures are
    // reported here rather than at an unrelated later CUDA call.
    const cudaError_t status = cudaMemsetAsync(ptr, value, size_in_bytes, stream);
    NVTE_CHECK(status == cudaSuccess,
               "cudaMemsetAsync failed: ", cudaGetErrorString(status));
    return;
  }
  // Try the widest vector type first; each dispatch falls through when the
  // size or pointer alignment rules out that width. The uint8_t dispatch
  // always matches (size >= 1, alignment 1), so every size is handled.
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
}

}  // extern "C"
void checkCuDriverContext(CUstream stream) {
CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
......
......@@ -338,6 +338,17 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config);
*/
int nvte_is_non_tn_fp8_gemm_supported();
/*! \brief Performs a memset of the data at the given pointer and size in bytes.
 *
 * \param[in] ptr            Pointer to the (device) memory to be set.
 * \param[in] value          Fill value; each byte of the memory is set to the
 *                           low byte of this value (same convention as memset).
 * \param[in] size_in_bytes  Size of the memory in bytes.
 * \param[in] stream         CUDA stream to use for the operation.
 *
 * This function calls a fill kernel for small sizes and calls cudaMemsetAsync for larger sizes.
 */
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
......@@ -53,7 +53,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
......@@ -266,7 +266,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
......
......@@ -123,7 +123,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
......
......@@ -122,7 +122,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment