Commit 24672339 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Merge branch 'feature/integrage-karg-simplification-pr' into feature/test

parents f2c5ca5a 853e797e
...@@ -51,7 +51,8 @@ __global__ void ...@@ -51,7 +51,8 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__)) defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
...@@ -553,7 +554,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout, ...@@ -553,7 +554,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
{ {
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" || if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx940") ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
...@@ -1027,7 +1027,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1027,7 +1027,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
{ {
return false; return false;
} }
......
...@@ -40,7 +40,8 @@ __global__ void ...@@ -40,7 +40,8 @@ __global__ void
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__)) defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__) || defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
......
...@@ -28,6 +28,7 @@ template <typename InDataType, ...@@ -28,6 +28,7 @@ template <typename InDataType,
typename AccElementwiseOperation, typename AccElementwiseOperation,
bool PropagateNan, bool PropagateNan,
bool OutputIndex, bool OutputIndex,
bool TransformIndexKtoGlobal,
bool HaveIndexInputIfOutputIndex, bool HaveIndexInputIfOutputIndex,
index_t BlockSize, index_t BlockSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
...@@ -260,6 +261,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType, ...@@ -260,6 +261,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
const auto kernel = kernel_reduce_threadwise<GridwiseReduce, const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
OutputIndex, OutputIndex,
TransformIndexKtoGlobal,
HaveIndexInput, HaveIndexInput,
InDataType, InDataType,
OutDataType, OutDataType,
......
...@@ -15,6 +15,7 @@ namespace ck { ...@@ -15,6 +15,7 @@ namespace ck {
template <typename GridwiseReduction, template <typename GridwiseReduction,
bool OutputIndex, bool OutputIndex,
bool TransformIndexKtoGlobal,
bool HaveIndexInput, bool HaveIndexInput,
typename InDataType, typename InDataType,
typename OutDataType, typename OutDataType,
...@@ -48,7 +49,8 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, ...@@ -48,7 +49,8 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
} }
else else
{ {
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k, GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
...@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
}; };
template <bool HaveIndexInput> template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
...@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise
indexStart += KThreadSliceSize; indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength); } while(reducedLength < toReduceLength);
if constexpr(TransformIndexKtoGlobal)
{
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
const auto coord = make_tensor_coordinate(
in_grid_desc_m_k,
make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
accu_index_buf(I)));
accu_index_buf(I) = coord.GetOffset();
});
}
}; };
// for indiced operation, acc_elementwise_op shoud do nothing // for indiced operation, acc_elementwise_op shoud do nothing
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP
#include "data_type.hpp"
namespace ck {
__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_PRINT_HPP
#define CK_PRINT_HPP
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "container_helper.hpp"
#include "sequence.hpp"
namespace ck {
template <typename T>
__host__ __device__ void print_array(const char* s, T a)
{
constexpr index_t nsize = a.Size();
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n");
}
} // namespace ck
#endif
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#pragma once #pragma once
#include <cstdlib> #include <vector>
#include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#pragma once #pragma once
#include <cstdlib> #include <vector>
#include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment