Unverified Commit e60c5aea authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #36 from ROCm/lwpck-1299

Initial MI350 enablement.
parents 29dcb956 63df00cd
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
......
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0) if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
......
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0) if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
......
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
#endif #endif
// define general macros for various architectures // define general macros for various architectures
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
......
...@@ -55,14 +55,15 @@ inline bool is_xdl_supported() ...@@ -55,14 +55,15 @@ inline bool is_xdl_supported()
{ {
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
} }
inline bool is_lds_direct_load_supported() inline bool is_lds_direct_load_supported()
{ {
// Check if direct loads from global memory to LDS are supported. // Check if direct loads from global memory to LDS are supported.
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
ck::get_device_name() == "gfx950";
} }
inline bool is_navi1_supported() inline bool is_navi1_supported()
......
...@@ -602,9 +602,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -602,9 +602,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false; return false;
} }
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
std::is_same<ADataType, double>::value)
{ {
return false; return false;
} }
......
...@@ -294,7 +294,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -294,7 +294,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942")) ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{ {
return false; return false;
} }
......
...@@ -39,7 +39,7 @@ __global__ void ...@@ -39,7 +39,7 @@ __global__ void
const CElementwiseOperation c_element_op) const CElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace ck { namespace ck {
// Define the common macro for MI300 models // Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
......
...@@ -189,7 +189,7 @@ struct vector_type<T, 1> ...@@ -189,7 +189,7 @@ struct vector_type<T, 1>
} }
}; };
int static err = 0; __device__ int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2>
{ {
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace ck { namespace ck {
// Define the common macro for MI300 models // Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
...@@ -6,4 +6,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -6,4 +6,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
...@@ -10,4 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -10,4 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endif() endif()
endforeach() endforeach()
\ No newline at end of file
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
...@@ -10,4 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -10,4 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endif() endif()
endforeach() endforeach()
\ No newline at end of file
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
...@@ -26,4 +26,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -26,4 +26,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif() endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
...@@ -6,4 +6,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -6,4 +6,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment