Commit efd41464 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

fix

parent c7913947
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -375,7 +375,7 @@ void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
DeviceGemm<Row, Col, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
......@@ -638,6 +638,7 @@ struct DeviceOperationInstanceFactory<
{
add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
......
......@@ -16,7 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t
using I8 = int8_t;
using F16 = ck::half_t;
using F32 = float;
......
......@@ -26,3 +26,7 @@ add_test_executable(test_gemm_int8 gemm_int8.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_int8 PRIVATE utility device_gemm_instance)
endif()
add_test_executable(test_gemm_fp16_int8 gemm_fp16_int8.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_fp16_int8 PRIVATE utility device_gemm_instance)
endif()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using ADataType = ck::half_t;
using BDataType = int8_t;
using CDataType = ck::half_t;
using AccDataType = float;
#include "run_gemm_test.inc"
int main() { return run_gemm_test(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment