Unverified commit 496620a9, authored by vcherepanov-nv and committed by GitHub.
Browse files

Get rid of nvshmem dependency for cuBLASMp integration (#2661)



* Remove nvshmem usage
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Renamings
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* NCCL dependency
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Check for not yet allocated workspace
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Address greptile comments
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add a comment per greptile
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix a typo
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Display human-readable cuBLASMp error message
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cd098e42
...@@ -77,11 +77,6 @@ def setup_common_extension() -> CMakeExtension: ...@@ -77,11 +77,6 @@ def setup_common_extension() -> CMakeExtension:
f"nvidia-cublasmp-cu{cuda_version()[0]}" f"nvidia-cublasmp-cu{cuda_version()[0]}"
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
f"nvidia-nvshmem-cu{cuda_version()[0]}"
).locate_file("nvidia/nvshmem")
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
print("CMAKE_FLAGS:", cmake_flags[-2:])
# Add custom CMake arguments from environment variable # Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
......
...@@ -287,20 +287,18 @@ endif() ...@@ -287,20 +287,18 @@ endif()
# Optional cuBLASMp backend for tensor-parallel GEMMs (post-#2661: depends on
# NCCL instead of NVSHMEM). CUBLASMP_DIR must point at the cuBLASMp install
# root (see setup.py, which derives it from the nvidia-cublasmp-cuXX wheel).
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
  target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
  target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
  find_library(CUBLASMP_LIB
    NAMES cublasmp libcublasmp
    PATHS ${CUBLASMP_DIR}
    PATH_SUFFIXES lib
    REQUIRED)
  # NCCL ships with the CUDA toolkit / system paths, so no PATHS hint is given.
  find_library(NCCL_LIB
    NAMES nccl libnccl
    PATH_SUFFIXES lib
    REQUIRED)
  target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
  message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
endif()
# Hack to enable dynamic loading in cuDNN frontend # Hack to enable dynamic loading in cuDNN frontend
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include <cublasmp.h> #include <cublasmp.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <nvshmem.h>
#include <map> #include <map>
#include <memory> #include <memory>
...@@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n ...@@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
ctx->grid_row_major.get(), ctx->d_desc.get())); ctx->grid_row_major.get(), ctx->d_desc.get()));
const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue)); sizeof epilogue));
} }
...@@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo ...@@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a,
sizeof trans_a)); sizeof trans_a));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b,
sizeof trans_b)); sizeof trans_b));
cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
sizeof algo_attr)); sizeof algo_attr));
const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
if (is_fp8_dtype(a->dtype())) { if (is_fp8_dtype(a->dtype())) {
NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
sizeof scale_mode)); sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
&a->scale_inv.dptr, sizeof(void*))); &a->scale_inv.dptr, sizeof(void*)));
} }
if (is_fp8_dtype(b->dtype())) { if (is_fp8_dtype(b->dtype())) {
NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
sizeof scale_mode)); sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
&b->scale_inv.dptr, sizeof(void*))); &b->scale_inv.dptr, sizeof(void*)));
} }
if (is_fp8_dtype(d->dtype())) { if (is_fp8_dtype(d->dtype())) {
NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
sizeof scale_mode)); sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER,
&d->scale.dptr, sizeof(void*))); &d->scale.dptr, sizeof(void*)));
if (d->amax.dptr) { if (d->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
&d->amax.dptr, sizeof(void*))); &d->amax.dptr, sizeof(void*)));
} }
...@@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo ...@@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
// Might be set to ALLREDUCE before, need to OR with the new flags to set. // Might be set to ALLREDUCE before, need to OR with the new flags to set.
cublasMpMatmulEpilogue_t epilogue{}; cublasMpMatmulEpilogue_t epilogue{};
size_t size_read{}; size_t size_read{};
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue, &size_read)); sizeof epilogue, &size_read));
NVTE_CHECK(size_read == sizeof epilogue); NVTE_CHECK(size_read == sizeof epilogue);
...@@ -339,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo ...@@ -339,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad}); pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad});
it != flags_to_epilogue.end()) { it != flags_to_epilogue.end()) {
epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second); epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue)); sizeof epilogue));
} }
if (bias && bias->data.dptr) { if (bias && bias->data.dptr) {
cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
sizeof bias_type)); sizeof bias_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr,
sizeof bias->data.dptr)); sizeof bias->data.dptr));
} }
if (pre_act_out && pre_act_out->data.dptr) { if (pre_act_out && pre_act_out->data.dptr) {
cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
&aux_type, sizeof aux_type)); &aux_type, sizeof aux_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
&pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
sizeof ldd)); sizeof ldd));
if (is_fp8_dtype(pre_act_out->dtype())) { if (is_fp8_dtype(pre_act_out->dtype())) {
NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
&scale_mode, sizeof scale_mode)); &scale_mode, sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
&pre_act_out->scale.dptr, sizeof(void*))); &pre_act_out->scale.dptr, sizeof(void*)));
if (pre_act_out->amax.dptr) { if (pre_act_out->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
&pre_act_out->amax.dptr, sizeof(void*))); &pre_act_out->amax.dptr, sizeof(void*)));
} }
...@@ -382,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo ...@@ -382,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
} }
if (comm_sm_count) { if (comm_sm_count) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
&comm_sm_count, sizeof comm_sm_count)); &comm_sm_count, sizeof comm_sm_count));
} }
NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream));
size_t wrksp_size_device{}; size_t wrksp_size_device{};
size_t wrksp_size_host{}; size_t wrksp_size_host{};
...@@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo ...@@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
std::vector<uint8_t> workspace_host(wrksp_size_host); std::vector<uint8_t> workspace_host(wrksp_size_host);
if (ctx->workspace_size < wrksp_size_device) { if (ctx->workspace_size < wrksp_size_device) {
nvshmem_free(ctx->workspace); if (ctx->workspace) {
ctx->workspace = nvshmem_malloc(wrksp_size_device); NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
NVTE_CHECK_CUBLASMP(
cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device));
NVTE_CHECK_CUBLASMP(
cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device));
ctx->workspace_size = wrksp_size_device; ctx->workspace_size = wrksp_size_device;
} }
...@@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank ...@@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
nvshmemx_sync_all_on_stream(ctx->stream.get()); if (ctx->workspace) {
NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
delete ctx; delete ctx;
} }
......
...@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank ...@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
/*! \brief Destroy a comm-gemm context. /*! \brief Destroy a comm-gemm context.
* *
* \param[in] ctx Context to destroy. * \param[in] ctx Context to destroy.
*
* It's the caller's responsibility to synchronize all streams involved before calling this function.
*/ */
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
......
...@@ -100,7 +100,7 @@ ...@@ -100,7 +100,7 @@
do { \ do { \
const cublasMpStatus_t status = (expr); \ const cublasMpStatus_t status = (expr); \
if (status != CUBLASMP_STATUS_SUCCESS) { \ if (status != CUBLASMP_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status)); \
} \ } \
} while (false) } while (false)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment