"include/vscode:/vscode.git/clone" did not exist on "85bddfbcc64daeca26727b179779ac727d26ddf7"
Commit 76f2b6cd authored by danyao12's avatar danyao12
Browse files

merge develop to attn-train-develop-qloop

parents 9b4c780a 1ee99dca
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
...@@ -71,13 +71,13 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -71,13 +71,13 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using InputDataType = BF16; using InputDataType = F16;
using OutputDataType = F32; using OutputDataType = F16;
using GemmDataType = BF16; using GemmDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = INT32; // U16 using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
...@@ -88,7 +88,7 @@ static constexpr ck::index_t NumDimK = 1; ...@@ -88,7 +88,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1; static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4 // When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8 // When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4; static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
...@@ -103,7 +103,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia ...@@ -103,7 +103,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true; static constexpr bool Deterministic = false;
// DIM should be a multiple of 8. // DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template. // If DIM <= 32 , ues prototype1 1st template.
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <cstdlib> #include <cstdlib>
...@@ -217,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -217,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
size_t invariant_total_length = n; size_t invariant_total_length = n;
size_t reduce_total_length = h * w * c; size_t reduce_total_length = h * w * c;
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f); const double alpha = 1.0f;
const AccDataType beta = ck::type_convert<AccDataType>(0.0f); const double beta = 0.0f;
std::size_t num_thread = 1; std::size_t num_thread = 1;
...@@ -267,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -267,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
i_outLengths, i_outLengths,
{i_outStrides, i_outStrides}, {i_outStrides, i_outStrides},
reduceDims, reduceDims,
{&alpha, &alpha}, {alpha, alpha},
{&beta, &beta}, {beta, beta},
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
{mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()}, {mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()},
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}), ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <cstdlib> #include <cstdlib>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <cstdlib> #include <cstdlib>
......
add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp) add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp)
add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp)
add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp) add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp)
add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp) add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp)
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <limits> #include <limits>
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment