Commit efd41464 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

fix

parent c7913947
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -375,7 +375,7 @@ void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instances( ...@@ -375,7 +375,7 @@ void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Row, Col, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif #endif
...@@ -634,10 +634,11 @@ struct DeviceOperationInstanceFactory< ...@@ -634,10 +634,11 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instances(op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
{ {
add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances(op_ptrs);
} }
}
#endif #endif
return op_ptrs; return op_ptrs;
} }
......
...@@ -16,7 +16,7 @@ namespace tensor_operation { ...@@ -16,7 +16,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using I8 = int8_t using I8 = int8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
......
...@@ -25,4 +25,8 @@ endif() ...@@ -25,4 +25,8 @@ endif()
add_test_executable(test_gemm_int8 gemm_int8.cpp) add_test_executable(test_gemm_int8 gemm_int8.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_gemm_int8 PRIVATE utility device_gemm_instance) target_link_libraries(test_gemm_int8 PRIVATE utility device_gemm_instance)
endif()
add_test_executable(test_gemm_fp16_int8 gemm_fp16_int8.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_fp16_int8 PRIVATE utility device_gemm_instance)
endif() endif()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using ADataType = ck::half_t;
using BDataType = int8_t;
using CDataType = ck::half_t;
using AccDataType = float;
#include "run_gemm_test.inc"
int main() { return run_gemm_test(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment