Unverified Commit eef7ab50 authored by Zhean Xu's avatar Zhean Xu Committed by GitHub
Browse files

feat: support cluster size 2 (#283)


Co-authored-by: default avatarZhean Xu <xza@deepseek.com>
parent e6d61fc6
...@@ -7,11 +7,15 @@ ...@@ -7,11 +7,15 @@
#ifndef DISABLE_SM90_FEATURES #ifndef DISABLE_SM90_FEATURES
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \ #define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \ cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
cudaLaunchAttribute attr[1]; \ cudaLaunchAttribute attr[2]; \
attr[0].id = cudaLaunchAttributeCooperative; \ attr[0].id = cudaLaunchAttributeCooperative; \
attr[0].val.cooperative = 1; \ attr[0].val.cooperative = 1; \
attr[1].id = cudaLaunchAttributeClusterDimension; \
attr[1].val.clusterDim.x = (num_sms % 2 == 0 ? 2 : 1); \
attr[1].val.clusterDim.y = 1; \
attr[1].val.clusterDim.z = 1; \
cfg.attrs = attr; \ cfg.attrs = attr; \
cfg.numAttrs = 1 cfg.numAttrs = 2
#else #else
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
int __num_sms = (sms); \ int __num_sms = (sms); \
...@@ -69,13 +73,13 @@ cfg.dynamicSmemBytes = smem_size; ...@@ -69,13 +73,13 @@ cfg.dynamicSmemBytes = smem_size;
case 2: case_macro(dtype, 2); \ case 2: case_macro(dtype, 2); \
case 4: case_macro(dtype, 4); \ case 4: case_macro(dtype, 4); \
case 8: case_macro(dtype, 8); \ case 8: case_macro(dtype, 8); \
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \ default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
} while (false) } while (false)
#define SWITCH_TYPES(case_macro) \ #define SWITCH_TYPES(case_macro) \
switch (type) { \ switch (type) { \
case CUDA_R_16BF: case_macro(nv_bfloat16); \ case CUDA_R_16BF: case_macro(nv_bfloat16); \
default: EP_HOST_ASSERT(false && "Unsupported type"); \ default: EP_HOST_ASSERT(false and "Unsupported type"); \
} while (false) } while (false)
#define SWITCH_HIDDEN(case_macro) \ #define SWITCH_HIDDEN(case_macro) \
...@@ -86,5 +90,5 @@ cfg.dynamicSmemBytes = smem_size; ...@@ -86,5 +90,5 @@ cfg.dynamicSmemBytes = smem_size;
case 5120: case_macro(5120); \ case 5120: case_macro(5120); \
case 7168: case_macro(7168); \ case 7168: case_macro(7168); \
case 8192: case_macro(8192); \ case 8192: case_macro(8192); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \ default: EP_HOST_ASSERT(false and "Unsupported hidden"); \
} while (false) } while (false)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment