Commit 08ab9cfa authored by aska-0096's avatar aska-0096
Browse files

fix a typo of name

parent 8a6e65a3
...@@ -182,9 +182,9 @@ int run(int argc, char* argv[]) ...@@ -182,9 +182,9 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......
...@@ -185,9 +185,9 @@ int run(int argc, char* argv[]) ...@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......
...@@ -185,9 +185,9 @@ int run(int argc, char* argv[]) ...@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......
...@@ -215,9 +215,9 @@ int run(int argc, char* argv[]) ...@@ -215,9 +215,9 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeSelfAttnInvoker(); auto invoker = gemm.MakeSelfAttnInvoker();
auto argument = auto argument =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment