"tests/vscode:/vscode.git/clone" did not exist on "4c6fd258808ed42fc98a94f3a849f5fc9efebc20"
Commit 7581e4cb authored by 王敏's avatar 王敏
Browse files

[fix]修复0.7.2版本benchmark_moe因新增rocm参数报错问题

parent 249fca2a
...@@ -183,7 +183,8 @@ def benchmark_config( ...@@ -183,7 +183,8 @@ def benchmark_config(
def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_mn_range = [16, 32, 64, 128, 256] block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256] block_k_range = [16, 32, 64, 128, 256]
if not use_fp16: if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
...@@ -195,8 +196,8 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): ...@@ -195,8 +196,8 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
kpack_range = [1, 2] if use_fp16 else [] kpack_range = [1, 2] if use_fp16 else []
param_ranges = { param_ranges = {
"BLOCK_SIZE_M": block_mn_range, "BLOCK_SIZE_M": block_m_range,
"BLOCK_SIZE_N": block_mn_range, "BLOCK_SIZE_N": block_n_range,
"BLOCK_SIZE_K": block_k_range, "BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range, "GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range, "num_warps": num_warps_range,
...@@ -204,11 +205,12 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): ...@@ -204,11 +205,12 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
"waves_per_eu": waves_per_eu_range, "waves_per_eu": waves_per_eu_range,
} }
if nn_moe: if nn_moe:
param_ranges["num_ldmatrixes"] = 1 param_ranges["num_ldmatrixes"] = [1]
if use_fp16: # DCU currently does not support the following parameters
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range # if use_fp16:
param_ranges["kpack"] = kpack_range # param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
# param_ranges["kpack"] = kpack_range
return param_ranges return param_ranges
...@@ -277,10 +279,11 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): ...@@ -277,10 +279,11 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps") num_warps = config.get("num_warps")
if is_fp16: # DCU currently does not support matrix_instr_nonkdim param
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") # if is_fp16:
if matrix_instr_nonkdim > mfma: # matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
continue # if matrix_instr_nonkdim > mfma:
# continue
if mfma == 4 and BLOCK_SIZE_K < 64: if mfma == 4 and BLOCK_SIZE_K < 64:
continue continue
# some layouts could not work properly in case # some layouts could not work properly in case
...@@ -289,16 +292,18 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): ...@@ -289,16 +292,18 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
continue continue
SPLIT_K = config.get("SPLIT_K", 1) SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M") GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16:
if (matrix_instr_nonkdim > BLOCK_SIZE_M # DCU currently does not support matrix_instr_nonkdim param
or matrix_instr_nonkdim > BLOCK_SIZE_N): # if is_fp16:
continue # if (matrix_instr_nonkdim > BLOCK_SIZE_M
if (matrix_instr_nonkdim >= M # or matrix_instr_nonkdim > BLOCK_SIZE_N):
and matrix_instr_nonkdim != BLOCK_SIZE_M): # continue
continue # if (matrix_instr_nonkdim >= M
if (matrix_instr_nonkdim >= N # and matrix_instr_nonkdim != BLOCK_SIZE_M):
and matrix_instr_nonkdim != BLOCK_SIZE_N): # continue
continue # if (matrix_instr_nonkdim >= N
# and matrix_instr_nonkdim != BLOCK_SIZE_N):
# continue
# Skip BLOCK_SIZE that is too large compare to M/N # Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough # unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
...@@ -452,7 +457,6 @@ class BenchmarkWorker: ...@@ -452,7 +457,6 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
if "num_ldmatrixes" not in config:
return { return {
"BLOCK_SIZE_M": "BLOCK_SIZE_M":
config["BLOCK_SIZE_M"], config["BLOCK_SIZE_M"],
...@@ -467,31 +471,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: ...@@ -467,31 +471,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
"num_stages": "num_stages":
config["num_stages"], config["num_stages"],
**({ **({
"waves_per_eu": config["waves_per_eu"] "num_ldmatrixes": config["num_ldmatrixes"]
} if "waves_per_eu" in config else {}), } if "num_ldmatrixes" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
}
else:
return {
"BLOCK_SIZE_M":
config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N":
config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K":
config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
"num_ldmatrixes":
config["num_ldmatrixes"],
**({ **({
"waves_per_eu": config["waves_per_eu"] "waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}), } if "waves_per_eu" in config else {}),
...@@ -643,7 +624,7 @@ if __name__ == "__main__": ...@@ -643,7 +624,7 @@ if __name__ == "__main__":
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
parser.add_argument("--nn-moe", action='store_true', default=False) parser.add_argument("--nn-moe", action='store_true', default=False)
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--moe-ep-size", type=int, default=1) parser.add_argument("--moe-ep-size", "-ep", type=int, default=1)
parser.add_argument("--num-gpus", type=int, default=1) parser.add_argument("--num-gpus", type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment