Unverified Commit 1ec33ae1 authored by Rachit Garg's avatar Rachit Garg Committed by GitHub
Browse files

Rachitg/dp carveout (#722)



* fix the perf regression because of constant property polling of the device
Signed-off-by: default avatarRachit Garg <rachitg@nvidia.com>

* Fix lint error
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarRachit Garg <rachitg@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarRachit Garg <rachitg@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
parent ffa24475
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include "common/util/system.h" #include "common/util/system.h"
#include "common/util/cuda_runtime.h"
namespace { namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) { transformer_engine::DType reverse_map_dtype(int64_t dtype) {
...@@ -320,10 +321,9 @@ at::Tensor te_gemm_ts(at::Tensor A, ...@@ -320,10 +321,9 @@ at::Tensor te_gemm_ts(at::Tensor A,
// Set an external SM Margin to all the GEMMs. // Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs // This comes in handy when DP is overlapped with GEMMs
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0); const int sm_count = transformer_engine::cuda::sm_count();
int num_math_sms = prop.multiProcessorCount \ int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
- transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);
if (A_scale_inverse.numel()) if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor]; A_scale_inverse = A_scale_inverse[A_fp8_tensor];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment