Unverified Commit c582f6be authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Common] Reduce CUDA driver calls (#2067)



* reduce driver calls
Signed-off-by: Xin Yao <xiny@nvidia.com>

* reduce driver calls
Signed-off-by: Xin Yao <xiny@nvidia.com>

* adjust tests to capture this
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 44fbe9e6
...@@ -122,10 +122,12 @@ class _Sequential(torch.nn.Sequential): ...@@ -122,10 +122,12 @@ class _Sequential(torch.nn.Sequential):
# Supported modules # Supported modules
_test_cuda_graphs_modules: List[str] = [ _test_cuda_graphs_modules: List[str] = [
# Put linear first to test the case where the cuda context might not be set in
# creating TMA descriptor for MXFP8 quantization.
"linear",
"transformer", "transformer",
"layernorm_mlp", "layernorm_mlp",
"layernorm_linear", "layernorm_linear",
"linear",
"mha", "mha",
"linear_op", "linear_op",
] ]
...@@ -308,9 +310,11 @@ def test_make_graphed_callables( ...@@ -308,9 +310,11 @@ def test_make_graphed_callables(
fp8_weight_caching=fp8_weight_caching, fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
) )
outputs = _test_cuda_graphs(graph_mode="none", **kwargs) # Put graphed callables first to test the case where the cuda context might not be set in
# creating TMA descriptor for MXFP8 quantization.
graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs) graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
# Check that results match. # Check that results match.
assert_all_equal(outputs, graph_outputs_mode1) assert_all_equal(outputs, graph_outputs_mode1)
......
...@@ -95,6 +95,9 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream ...@@ -95,6 +95,9 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream
} // extern "C" } // extern "C"
void checkCuDriverContext(CUstream stream) { void checkCuDriverContext(CUstream stream) {
// Ensure the thread's "current" CUDA context is set.
cuda_driver::ensure_context_exists();
CUcontext ctx; CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx); const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
switch (driver_status) { switch (driver_status) {
...@@ -138,7 +141,6 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, ...@@ -138,7 +141,6 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems, const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_num_bits) { const uint32_t offset_elems, const size_t type_num_bits) {
cuda_driver::ensure_context_exists();
// Get a function pointer to the cuTensorMapEncodeTiled driver API // Get a function pointer to the cuTensorMapEncodeTiled driver API
// Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
......
...@@ -488,6 +488,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -488,6 +488,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
const bool return_transpose, const bool pow_2_scale, const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_square_blockwise); NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream);
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1; size_t num_rows = 1;
......
...@@ -885,6 +885,8 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP ...@@ -885,6 +885,8 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
checkCuDriverContext(stream);
if (output->has_data()) { if (output->has_data()) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
} }
...@@ -964,6 +966,8 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP ...@@ -964,6 +966,8 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
checkCuDriverContext(stream);
const bool USE_ROWWISE_SCALING = output->has_data(); const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data();
...@@ -1206,7 +1210,6 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP ...@@ -1206,7 +1210,6 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
checkCuDriverContext(stream);
constexpr bool allow_empty = false; constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input"); CheckInputTensor(gated_input, "gated_input");
CheckOutputTensor(*output, "output", allow_empty); CheckOutputTensor(*output, "output", allow_empty);
......
...@@ -1006,9 +1006,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1006,9 +1006,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const Tensor *noop, // TODO (ksivamani) const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
using namespace mxfp8_kernel; using namespace mxfp8_kernel;
checkCuDriverContext(stream);
bool use_rowwise_scaling = output->has_data(); bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data(); bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
......
...@@ -45,16 +45,19 @@ void *get_symbol(const char *symbol, int cuda_version) { ...@@ -45,16 +45,19 @@ void *get_symbol(const char *symbol, int cuda_version) {
} }
void ensure_context_exists() { void ensure_context_exists() {
CUcontext context; static thread_local bool need_check = []() {
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context); CUcontext context;
if (context == nullptr) { NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
// Add primary context to context stack if (context == nullptr) {
CUdevice device; // Add primary context to context stack
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device()); CUdevice device;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device());
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device); NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
} NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
}
return false;
}();
} }
} // namespace cuda_driver } // namespace cuda_driver
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment