Commit 8dd7156d authored by ltqin's avatar ltqin
Browse files

Merge branch 'mha-train-develop' into attn-train-develop-qloop-mask

parents d5f629e7 b5a3ea2d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = float;
using WeiDataType = float;
using OutDataType = float;
using InLayout = ck::tensor_layout::convolution::NDHWC;
using WeiLayout = ck::tensor_layout::convolution::KZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28;
int main()
{
return run_conv_bwd_data<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_operations)
\ No newline at end of file
This diff is collapsed.
add_executable(client_reduce_nhwc_c reduce_nhwc_c.cpp)
target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_operations)
This diff is collapsed.
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations)
\ No newline at end of file
add_executable(client_groupnorm_swish groupnorm_swish.cpp)
target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_operations)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment