"...resnet50_tensorflow.git" did not exist on "dfcc691caeed5e8633a2b87b3ce2bbb851fd57c0"
Commit 1a7ce816 authored by fsx950223's avatar fsx950223
Browse files

format example

parent 478df149
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
...@@ -20,6 +19,7 @@ ...@@ -20,6 +19,7 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
// clang-format off
using EmbType = ck::half_t; using EmbType = ck::half_t;
using IndexType = int64_t; using IndexType = int64_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
...@@ -126,19 +126,20 @@ int main() ...@@ -126,19 +126,20 @@ int main()
beta_dev.ToDevice(beta.mData.data()); beta_dev.ToDevice(beta.mData.data());
auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{}; auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(), auto argument_ptr = device_instance.MakeArgumentPointer(
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()), out_dev.GetDeviceBuffer(),
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()), {ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())}, ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()), ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()), {ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())}, ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
gamma_dev.GetDeviceBuffer(), ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
beta_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
current_dim, beta_dev.GetDeviceBuffer(),
index_length, current_dim,
epsilon, index_length,
EmbElementwiseOperation{}); epsilon,
EmbElementwiseOperation{});
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
<< std::endl << std::endl
<< std::flush; << std::flush;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment