Commit 41b54611 authored by root's avatar root
Browse files

Merge branch 'develop' into gemm_bf16_sk_muozturk

parents 8a34c640 aeb7c91f
......@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -390,7 +390,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......
......@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -518,7 +518,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......@@ -575,7 +576,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......
......@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -504,7 +504,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......
......@@ -64,9 +64,9 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
# Do not build mha instances if gfx94 targets are not on the target list
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
message("removing mha instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
......@@ -85,7 +85,7 @@ function(add_instance_library INSTANCE_NAME)
elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "mha")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
endif()
set(offload_targets)
foreach(target IN LISTS INST_TARGETS)
......@@ -320,8 +320,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif()
if(CK_DEVICE_MHA_INSTANCES)
set(gpu_list ${INST_TARGETS})
list(FILTER gpu_list INCLUDE REGEX "^gfx94")
if(gpu_list)
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
target_compile_features(device_mha_operations PUBLIC)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment