Commit 47cc9b7e authored by Astha Rai

Added shared-library compilation and multiple GEMM instances; cleaned up code design.

parent adbefd90
#pragma once
#include "model-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include <condition_variable>
#include <cstring>
#include <future>
#include <mutex>
#include <numeric>
#include <unordered_map>
namespace ait {
// ModelContainer inherits from this class; its implementation is
// generated at compilation time. Most of the ModelContainer
// logic does not need codegen; anything that does should be put
// into this class instead.
class ModelContainerBase {
public:
ModelContainerBase(
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
size_t params_size,
AITemplateAllocator& allocator);
protected:
// The set of unbound constants/weights/parameters. These are constants which
// have no value at compile time and do not participate in constant folding.
// They must be set via SetConstant prior to inference.
std::unordered_map<std::string, size_t> unbound_constant_name_to_idx_;
// a single piece of memory for all constants
GPUPtr constants_;
// size of the containers below: # inputs + # outputs + # unbound constants.
size_t num_params_;
// These entries correspond to inputs/outputs/unbound constants in order;
// inputs first, then outputs, then constants.
std::vector<const char*> param_names_;
std::vector<std::vector<int64_t>> max_param_shapes_;
std::vector<AITemplateDtype> param_dtypes_;
// NB: technically these could be derived from both the max shape and
// the dtype, but it's easier to just cache them.
std::vector<size_t> max_param_storage_bytes_;
std::vector<size_t> max_param_numel_;
};
// This creates a new ModelContainer; its implementation is also
// codegened (the parameters passed to the ctor are determined
// at compilation time)
class ModelContainer;
ModelContainer* CreateModelContainer(
size_t num_runtimes,
AITemplateAllocator& allocator);
// Each ModelContainer contains num_models Models. Inference runs
// can be started by invoking Run() with lists of pre-allocated
// input/output tensors. GetOutputMaximumShape() can be used to
// determine how much memory is required for each output.
//
// If there are N tensors marked with is_output=True,
// the user will always be expected to pass N output pointers -
// extra copies will occur if the outputs are views of constants,
// inputs, or other outputs in this case to avoid surprises.
//
// Use stream = nullptr for default stream. ModelContainer/Model does not
// create or own any stream. The user is expected to create and manage streams.
//
// We can support at most num_models concurrent inferences.
// Run() takes a stream to run the inference on. For example,
// to start up two inferences on different streams concurrently,
// we can do this:
//
// model_container.Run(inputs0, num_inputs, outputs0, num_outputs, stream0, ...);
// model_container.Run(inputs1, num_inputs, outputs1, num_outputs, stream1, ...);
// StreamSynchronize(stream0);
// StreamSynchronize(stream1);
//
// Note that if there are no models available for inference, Run() will block
// until one becomes available.
//
// ModelContainer optionally takes an allocator argument, which it will use to
// allocate the space for the buffers used for intermediate tensors and
// constants. If it is nullptr, the default allocator will be used (e.g. just
// {cuda/hip}{Malloc/Free}).
// Important: we assume that the allocator lives until the ModelContainer is
// destroyed. The default allocator has a static lifetime.
class ModelContainer : ModelContainerBase {
public:
ModelContainer(
size_t num_models,
size_t blob_size,
size_t workspace_size,
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
size_t params_size,
AITemplateAllocator& allocator);
void Run(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool sync,
bool graph_mode,
int64_t** output_shapes_out);
void RunWithOutputsOnHost(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
int64_t** output_shapes_out);
float Benchmark(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
size_t count,
size_t num_threads,
bool use_unique_stream_per_thread,
int64_t** output_shapes_out);
void SetConstant(const char* name, const AITData& tensor);
size_t NumInputs() const;
size_t NumOutputs() const;
const char* InputName(size_t input_idx) const;
const char* OutputName(size_t output_idx) const;
AITemplateParamShape MaxOutputShape(size_t output_idx) const;
AITemplateDtype OutputDtype(size_t output_idx) const;
size_t MaxOutputStorageBytes(size_t output_idx) const;
size_t GetNumRuntimes() const {
return models_.size();
}
private:
void PrepareForRun(
Model* model,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs);
Model* GetAvailableModel();
void ReclaimFinishedModels(std::unique_lock<std::mutex>& lk);
void ValidateDtype(AITemplateDtype dtype, size_t idx) const;
float BenchmarkImpl(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
size_t count,
int64_t** output_shapes_out);
AITemplateAllocator& allocator_;
std::vector<Model> models_;
std::vector<Model*> available_models_;
std::deque<Model*> pending_models_;
// Guards accesses to available/pending models.
std::mutex models_mutex_;
// Notified whenever a model is put into pending_models_.
std::condition_variable pending_models_available_;
size_t num_inputs_;
size_t num_outputs_;
};
} // namespace ait
#include "model_container.h"
#include "owned_constants.h"
namespace ait {
namespace {
// Contains the metadata for each constant.
constexpr std::array<ConstantInfo, 0> owned_constants = {
};
} // namespace
ModelContainerBase::ModelContainerBase(
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
size_t params_size,
AITemplateAllocator& allocator)
: constants_(RAII_DeviceMalloc(params_size, allocator)),
num_params_(num_inputs + num_outputs + num_unbound_constants),
param_names_(num_params_),
param_dtypes_(num_params_),
max_param_shapes_(num_params_),
max_param_numel_(num_params_),
max_param_storage_bytes_(num_params_) {
param_names_[0] = "input_0";
param_names_[1] = "input_1";
param_names_[2] = "output_0";
param_dtypes_[0] = AITemplateDtype::kHalf;
param_dtypes_[1] = AITemplateDtype::kHalf;
param_dtypes_[2] = AITemplateDtype::kHalf;
max_param_shapes_[0] = {256, 128};
max_param_shapes_[1] = {128, 32};
max_param_shapes_[2] = {256, 32};
for (size_t i = 0; i < num_params_; ++i) {
max_param_numel_[i] = std::accumulate(
max_param_shapes_[i].begin(),
max_param_shapes_[i].end(),
int64_t{1},
std::multiplies<int64_t>()
);
max_param_storage_bytes_[i] = max_param_numel_[i] * AITemplateDtypeSizeBytes(param_dtypes_[i]);
}
auto* constants_ptr = static_cast<uint8_t*>(constants_.get());
const auto binary_constants_bin_size = static_cast<size_t>(_binary_constants_bin_end - _binary_constants_bin_start);
for (auto& constant_info : owned_constants) {
auto* dst = constants_ptr + constant_info.internal_offset;
if (constant_info.data_offset + constant_info.num_bytes > binary_constants_bin_size) {
throw std::runtime_error(std::string("Copying constant ") + constant_info.name + " would overflow constant buffer");
}
DEVICE_CHECK(CopyToDevice(dst, _binary_constants_bin_start + constant_info.data_offset, constant_info.num_bytes));
}
}
ModelContainer* CreateModelContainer(size_t num_runtimes, AITemplateAllocator& allocator) {
// num_runtimes, blob_size, workspace_size, num_inputs, num_outputs, num_unbound_constants, param_size, allocator
return new ModelContainer(num_runtimes, 90112, 0, 2, 1, 0, 0, allocator);
}
} // namespace ait
#include "model_interface.h"
#include <iostream>
#include <unordered_map>
#include "model-generated.h"
#include "model_container.h"
// Important: don't let exceptions escape the functions below.
// They can cause problems when -fvisibility=hidden. But more
// importantly, they can crash the program if they try to cross
// the language boundary into Python.
#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
try { \
__VA_ARGS__ \
} catch (const std::exception& e) { \
LOG(ERROR) << "Error: " << e.what(); \
return AITemplateError::AITemplateFailure; \
} catch (...) { \
LOG(ERROR) << "Unknown exception occurred."; \
return AITemplateError::AITemplateFailure; \
} \
return AITemplateError::AITemplateSuccess;
#define RETURN_ERROR_IF_NULL(var) \
if (var == nullptr) { \
LOG(ERROR) << "Variable " << #var << " can't be null"; \
return AITemplateError::AITemplateFailure; \
}
namespace ait {
namespace {
class DefaultAllocator : public AITemplateAllocator {
public:
void* Allocate(size_t n_bytes) override {
void* result;
DEVICE_CHECK(DeviceMalloc(&result, n_bytes));
return result;
}
void Free(void* ptr) override {
DEVICE_CHECK(FreeDeviceMemory(ptr));
}
};
class TrackingAllocator : public DefaultAllocator {
public:
void* Allocate(size_t n_bytes) override {
auto* result = DefaultAllocator::Allocate(n_bytes);
num_bytes_ += n_bytes;
return result;
}
size_t NumBytesAllocated() const {
return num_bytes_;
}
private:
size_t num_bytes_ = 0;
};
DefaultAllocator default_allocator;
} // namespace
} // namespace ait
extern "C" {
AITemplateError AITemplateModelContainerCreate(
AITemplateModelHandle* ret,
size_t num_runtimes,
AITemplateAllocator* allocator) {
if (num_runtimes == 0) {
LOG(ERROR) << "num_runtimes must be positive, but got 0";
return AITemplateError::AITemplateFailure;
}
RETURN_ERROR_IF_NULL(ret)
AITemplateAllocator& allocator_ref =
allocator == nullptr ? ait::default_allocator : *allocator;
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* m = ait::CreateModelContainer(num_runtimes, allocator_ref);
*ret = reinterpret_cast<AITemplateModelHandle>(m);
})
}
AITemplateError AITemplateModelContainerDelete(AITemplateModelHandle handle) {
RETURN_ERROR_IF_NULL(handle)
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
delete m;
});
}
AITemplateError AITemplateModelContainerSetConstant(
AITemplateModelHandle handle,
const char* name,
const AITData* tensor) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(tensor)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({ m->SetConstant(name, *tensor); })
}
AITemplateError AITemplateModelContainerRun(
AITemplateModelHandle handle,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
AITemplateStreamHandle stream_handle,
bool sync,
bool graph_mode,
int64_t** output_shapes_out) {
RETURN_ERROR_IF_NULL(handle)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
auto stream = reinterpret_cast<ait::StreamType>(stream_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
m->Run(
inputs,
num_inputs,
outputs,
num_outputs,
stream,
sync,
graph_mode,
output_shapes_out);
})
}
AITemplateError AITemplateModelContainerRunWithOutputsOnHost(
AITemplateModelHandle handle,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
AITemplateStreamHandle stream_handle,
bool graph_mode,
int64_t** output_shapes_out) {
RETURN_ERROR_IF_NULL(handle)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
auto stream = reinterpret_cast<ait::StreamType>(stream_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
m->RunWithOutputsOnHost(
inputs,
num_inputs,
outputs,
num_outputs,
stream,
graph_mode,
output_shapes_out);
})
}
AITemplateError AITemplateModelContainerBenchmark(
AITemplateModelHandle handle,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
AITemplateStreamHandle stream_handle,
bool graph_mode,
size_t count,
size_t num_threads,
bool use_unique_stream_per_thread,
float* runtime_ms,
int64_t** output_shapes_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(runtime_ms)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
auto stream = reinterpret_cast<ait::StreamType>(stream_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
*runtime_ms = m->Benchmark(
inputs,
num_inputs,
outputs,
num_outputs,
stream,
graph_mode,
count,
num_threads,
use_unique_stream_per_thread,
output_shapes_out);
})
}
AITemplateError AITemplateModelContainerGetNumInputs(
AITemplateModelHandle handle,
size_t* num_inputs_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(num_inputs_out)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({ *num_inputs_out = m->NumInputs(); })
}
AITemplateError AITemplateModelContainerGetInputName(
AITemplateModelHandle handle,
size_t input_idx,
const char** input_name_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(input_name_out)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *input_name_out = m->InputName(input_idx); })
}
AITemplateError AITemplateModelContainerGetNumOutputs(
AITemplateModelHandle handle,
size_t* num_outputs_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(num_outputs_out)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({ *num_outputs_out = m->NumOutputs(); })
}
AITemplateError AITemplateModelContainerGetOutputName(
AITemplateModelHandle handle,
size_t output_idx,
const char** output_name_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(output_name_out)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *output_name_out = m->OutputName(output_idx); })
}
AITemplateError AITemplateModelContainerGetMaximumOutputShape(
AITemplateModelHandle handle,
size_t output_idx,
AITemplateParamShape* shape_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(shape_out)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *shape_out = m->MaxOutputShape(output_idx); })
}
AITemplateError AITemplateModelContainerGetOutputDtype(
AITemplateModelHandle handle,
size_t output_idx,
AITemplateDtype* dtype_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(dtype_out)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({ *dtype_out = m->OutputDtype(output_idx); })
}
AITemplateError AITemplateModelContainerGetNumRuntimes(
AITemplateModelHandle handle,
size_t* num_runtimes_out) {
RETURN_ERROR_IF_NULL(handle)
RETURN_ERROR_IF_NULL(num_runtimes_out)
auto* m = reinterpret_cast<ait::ModelContainer*>(handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({ *num_runtimes_out = m->GetNumRuntimes(); })
}
AITemplateError AITemplateAllocatorCreate(
AITemplateAllocator** allocator_out,
AITemplateAllocatorType allocator_type) {
RETURN_ERROR_IF_NULL(allocator_out);
CONVERT_EXCEPTION_TO_ERROR_CODE({
switch (allocator_type) {
case AITemplateAllocatorType::kDefault:
*allocator_out = new ait::DefaultAllocator();
break;
case AITemplateAllocatorType::kTracking:
*allocator_out = new ait::TrackingAllocator();
break;
default:
throw std::runtime_error("Unrecognized allocator type");
}
});
}
AITemplateError AITemplateAllocatorDelete(AITemplateAllocator* allocator) {
RETURN_ERROR_IF_NULL(allocator);
delete allocator;
return AITemplateError::AITemplateSuccess;
}
AITemplateError AITemplateTrackingAllocatorGetNumBytes(
AITemplateAllocator* allocator,
size_t* num_bytes_out) {
RETURN_ERROR_IF_NULL(allocator);
RETURN_ERROR_IF_NULL(num_bytes_out);
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* tracking_allocator = dynamic_cast<ait::TrackingAllocator*>(allocator);
if (tracking_allocator == nullptr) {
throw std::runtime_error("Allocator was not a tracking allocator!");
}
*num_bytes_out = tracking_allocator->NumBytesAllocated();
});
}
} // extern "C"
\ No newline at end of file
#pragma once
#include <stddef.h>
#include <stdint.h>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with AIT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define AIT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define AIT_EXPORT __declspec(dllexport)
#else
#define AIT_EXPORT
#endif
#endif
struct AITemplateModelOpaque {};
using AITemplateModelHandle = AITemplateModelOpaque*;
enum class AITemplateError : int {
AITemplateSuccess = 0,
AITemplateFailure = 1,
};
struct AITemplateParamShape {
AITemplateParamShape() : shape_data(nullptr), size(0) {}
AITemplateParamShape(const int64_t* shape_data_in, size_t size_in)
: shape_data(shape_data_in), size(size_in) {}
const int64_t* shape_data;
size_t size;
size_t Numel() const {
return std::accumulate(
shape_data, shape_data + size, int64_t{1}, std::multiplies<int64_t>());
}
};
enum class AITemplateDtype {
kUnset = 0,
kHalf,
kFloat,
kInt,
kLong,
kBool,
};
struct AITData {
AITData() : ptr(nullptr), dtype(AITemplateDtype::kUnset) {}
AITData(
void* ptr_in,
const AITemplateParamShape& shape_in,
AITemplateDtype dtype_in)
: ptr(ptr_in), shape(shape_in), dtype(dtype_in) {}
void* ptr;
AITemplateParamShape shape;
AITemplateDtype dtype;
};
inline size_t AITemplateDtypeSizeBytes(AITemplateDtype dtype) {
switch (dtype) {
case AITemplateDtype::kHalf:
return 2;
case AITemplateDtype::kFloat:
return 4;
case AITemplateDtype::kInt:
return 4;
case AITemplateDtype::kLong:
return 8;
case AITemplateDtype::kBool:
return 1;
case AITemplateDtype::kUnset:
throw std::runtime_error("Unset dtype has no size!");
}
}
struct AITemplateStreamOpaque {};
using AITemplateStreamHandle = AITemplateStreamOpaque*;
// Allocator to use for GPU mallocs and frees. Allocations will only happen
// when the ModelContainer is created.
class AITemplateAllocator {
public:
virtual void* Allocate(size_t nbytes) = 0;
virtual void Free(void* ptr) = 0;
virtual ~AITemplateAllocator() = default;
};
// Some custom allocators are provided. They can be created by passing
// an enum into the AITemplateAllocatorCreate() function.
enum class AITemplateAllocatorType {
// The default allocator just uses the backend's default malloc/free.
kDefault = 0,
// The tracking allocator is like the default allocator, but it keeps
// track of how many bytes it has allocated. Mainly used for testing.
kTracking,
};
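// A minimal allocator usage sketch (illustration only; error codes unchecked,
// and the container handle is assumed to be freed before the allocator):
//
//   AITemplateAllocator* allocator = nullptr;
//   AITemplateAllocatorCreate(&allocator, AITemplateAllocatorType::kTracking);
//   AITemplateModelHandle handle;
//   AITemplateModelContainerCreate(&handle, /*num_runtimes=*/1, allocator);
//   size_t allocated = 0;
//   AITemplateTrackingAllocatorGetNumBytes(allocator, &allocated);
//   // ... run inference ...
//   AITemplateModelContainerDelete(handle);
//   AITemplateAllocatorDelete(allocator);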
extern "C" {
// Create a ModelContainer. See model_container.h for all the details.
// Some important high-level notes:
// * If allocator is null, a default allocator is used (forwards to
// {cuda/hip}{Malloc/Free}).
// * We assume that the allocator lives at least as long as the ModelContainer.
AIT_EXPORT AITemplateError AITemplateModelContainerCreate(
AITemplateModelHandle* ret,
size_t num_runtimes,
AITemplateAllocator* allocator = nullptr);
AIT_EXPORT AITemplateError
AITemplateModelContainerDelete(AITemplateModelHandle handle);
AIT_EXPORT AITemplateError AITemplateModelContainerSetConstant(
AITemplateModelHandle handle,
const char* name,
const AITData* tensor);
AIT_EXPORT AITemplateError AITemplateModelContainerRun(
AITemplateModelHandle handle,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
AITemplateStreamHandle stream_handle,
bool sync,
bool graph_mode,
int64_t** output_shapes_out);
// Like AITemplateModelContainerRun, but expects outputs to be allocated on the
// host. Does an extra sync/copy at the end to copy them over. Warning: don't
// use this! It's not optimal with respect to performance. It's here for use by
// internal constant folding passes.
AIT_EXPORT AITemplateError AITemplateModelContainerRunWithOutputsOnHost(
AITemplateModelHandle handle,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
AITemplateStreamHandle stream_handle,
bool graph_mode,
int64_t** output_shapes_out);
AIT_EXPORT AITemplateError AITemplateModelContainerBenchmark(
AITemplateModelHandle handle,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
AITemplateStreamHandle stream_handle,
bool graph_mode,
size_t count,
size_t num_threads,
bool use_unique_stream_per_thread,
float* runtime_ms,
int64_t** output_shapes_out);
AIT_EXPORT AITemplateError AITemplateModelContainerGetNumInputs(
AITemplateModelHandle handle,
size_t* num_inputs_out);
AIT_EXPORT AITemplateError AITemplateModelContainerGetInputName(
AITemplateModelHandle handle,
size_t input_idx,
const char** input_name_out);
AIT_EXPORT AITemplateError AITemplateModelContainerGetNumOutputs(
AITemplateModelHandle handle,
size_t* num_outputs_out);
AIT_EXPORT AITemplateError AITemplateModelContainerGetOutputName(
AITemplateModelHandle handle,
size_t output_idx,
const char** output_name_out);
AIT_EXPORT AITemplateError AITemplateModelContainerGetMaximumOutputShape(
AITemplateModelHandle handle,
size_t output_idx,
AITemplateParamShape* shape_out);
AIT_EXPORT AITemplateError AITemplateModelContainerGetOutputDtype(
AITemplateModelHandle handle,
size_t output_idx,
AITemplateDtype* out);
AIT_EXPORT AITemplateError AITemplateModelContainerGetNumRuntimes(
AITemplateModelHandle handle,
size_t* num_runtimes_out);
AIT_EXPORT AITemplateError AITemplateAllocatorCreate(
AITemplateAllocator** allocator_out,
AITemplateAllocatorType allocator_type);
AIT_EXPORT AITemplateError
AITemplateAllocatorDelete(AITemplateAllocator* allocator);
// Get the number of bytes allocated; mainly used for testing.
AIT_EXPORT AITemplateError AITemplateTrackingAllocatorGetNumBytes(
AITemplateAllocator* allocator,
size_t* num_bytes_out);
} // extern "C"
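// A minimal end-to-end sketch of the C API above (illustration only; assumes a
// single output, caller-allocated device buffers, the default stream and
// allocator, and omits error checking):
//
//   AITemplateModelHandle handle;
//   AITemplateModelContainerCreate(&handle, /*num_runtimes=*/1);
//   AITemplateParamShape max_shape;
//   AITemplateDtype dtype;
//   AITemplateModelContainerGetMaximumOutputShape(handle, 0, &max_shape);
//   AITemplateModelContainerGetOutputDtype(handle, 0, &dtype);
//   // Allocate outputs[0].ptr with max_shape.Numel() * AITemplateDtypeSizeBytes(dtype)
//   // bytes on the device, fill the inputs, then:
//   int64_t out_shape[8];                     // >= max_shape.size entries
//   int64_t* output_shapes[1] = {out_shape};  // receives the actual output shape
//   AITemplateModelContainerRun(
//       handle, inputs, num_inputs, outputs, /*num_outputs=*/1,
//       /*stream_handle=*/nullptr, /*sync=*/true, /*graph_mode=*/false,
//       output_shapes);
//   AITemplateModelContainerDelete(handle);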
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Normalization common codegen for ROCM.
"""
import os
import re
from hashlib import sha1
from typing import Any, Dict, List, OrderedDict, Tuple
import jinja2
from ...target import Target
FUNC_CALL_PARAM_TEMPLATE = jinja2.Template("(void *)({{name}})")
INSTANCE_TEMPLATE = jinja2.Template(
"""
{{config}}
using {{name}} = {{ config_name }};
"""
)
ARGS_PARSE_TEMPLATE = jinja2.Template(
"""
{% for idx in range(rank) %}
const int64_t in_{{idx}} = std::stoi(argv[{{ idx + 1 }}]);
{% endfor %}
"""
)
STRUCTS_DEF_TEMPLATE = jinja2.Template(
"""
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
// The generator is used by AllocateGaussianTensor below, so create it up front.
rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
}
~ProfilerMemoryPool() {
rocrand_destroy_generator(generator);
for (size_t i = 0; i < ptrs.size(); i++) {
hipFree(ptrs[i]);
}
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
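// RequestHalfTensorByIdx returns the next copy of tensor `idx`: the offset
// advances by `stride` on each call and wraps around after `copy` copies, so
// repeated profiler invocations cycle through distinct buffers rather than
// reusing the same memory.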
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
KernelTimerImpl timer;
timer.Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer.End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
"""
)
FUNC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
{{extra_headers}}
{{extra_code}}
{{instances_decl}}
{{func_signature}}
{
{{shape_eval}}
{{exec_paths}}
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{input}},
{{indent}} {{output}},
{% for name in input_dim_names %}
{{indent}} const_cast<int64_t *>(&{{name}}),
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
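# Illustration only (the values below are hypothetical, not produced by this
# module): rendered with func_name="softmax_0", input="(void *)(input_0)",
# output="(void *)(output_0)", input_dim_names=["in_0", "in_1"] and
# indent="  ", FUNC_CALL_TEMPLATE produces roughly:
#
#   softmax_0(
#       (void *)(input_0),
#       (void *)(output_0),
#       const_cast<int64_t *>(&in_0),
#       const_cast<int64_t *>(&in_1),
#       stream
#   );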
def extract_config(func_attrs):
"""Extract (operation name, operation instance) pair
from all operation candidates.
Parameters
----------
op_kind : ck_lib.library.OperationKind
Operation kind.
extra_kind : ck_lib.library.[AnyKind]
Used to as extra flag to distinguish kernels.
E.g. bias_add_relu vs. add_relu_bias
f_prop_op: function
Used to filter operation.
Returns
-------
Dict
Extracted (operation name, operation instance) pair.
"""
import ck_lib
op_kind = ck_lib.library.OperationKind.Softmax
extra_kind = len(func_attrs["inputs"][0]._attrs["shape"])
extract_ops = list(Target.current()._operators[op_kind][extra_kind].items())
softmax_ops = OrderedDict()
for key, value in extract_ops:
softmax_ops[key] = value[0]
func_attrs["op_instance"] = softmax_ops
def emit_instance(op):
"""Emit instance"""
import ck_lib # noqa: F401
op_def = op.emit()
return op_def
def extract_config_name(config):
"""Extract configuration names.
Parameters
----------
config : str
Configuration as a string in the format of 'using model = xxx'.
Returns
-------
str
Extracted name from the statement, e.g. 'model' for 'using model = xxx'.
Raises
------
RuntimeError
Invalid config.
"""
pattern = re.compile(r"\s*using\s(.*?)\s=")
decl = config.split("\n")[1]
match = pattern.match(decl)
if match is None:
raise RuntimeError("Invalid config: \n" + config)
return match.groups()[0]
def gen_profiler(
func_attrs: Dict[str, Any],
workdir: str,
rank: int,
shape_eval_template: jinja2.Template,
exec_template: jinja2.Template,
tensor_decl_template: jinja2.Template,
extra_header_template: jinja2.Template,
get_func_signature: Any,
extra_code: str = "",
func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE,
indent: str = " ",
) -> List[Tuple[str, str]]:
"""Generates standalone executables for profiler.
Parameters
----------
func_attrs : Dict
Operation attributes.
workdir : str
Directory to store the generated outputs.
rank: int
Rank of the input tensor. If using [M, N] in exec_key, the rank here
must be 2 because if implies that the inputs are reshaped for profiling.
For code gen, the real shapes are used.
exec_template : jinja2.Template
Execution block template.
tensor_decl_template: jinja2.Template
Tensor declaration template.
extra_header_template : jinja2.Template
Extra header template.
indent : str, optional
Indent for codegen, target dependent e.g. C++, python, etc., by default " ".
"""
op_type = func_attrs["op"]
shape_eval = shape_eval_template.render(rank=rank) if shape_eval_template else ""
eps = func_attrs.get("eps", "1e-5")
op_instance = func_attrs["op_instance"]
file_pairs = []
for op_name, op in op_instance.items():
config = emit_instance(op)
config_name = extract_config_name(config)
instances = INSTANCE_TEMPLATE.render(
name="DeviceInstance", config_name=config_name, config=config
)
exe_path = exec_template.render(
instance="DeviceInstance",
dtype="void",
reduce_dims=rank - 1,
rank=rank,
eps=eps,
)
op_func = FUNC_TEMPLATE.render(
instances_decl=instances,
func_signature=get_func_signature(func_attrs),
shape_eval=shape_eval,
exec_paths=exe_path,
extra_headers=extra_header_template.render(),
extra_code=extra_code,
)
structs_def = STRUCTS_DEF_TEMPLATE.render()
args_parse = ARGS_PARSE_TEMPLATE.render(rank=rank)
tensor_decl = tensor_decl_template.render(rank=rank)
input_dim_names = [f"in_{i}" for i in range(rank)]
func_call = func_call_template.render(
func_name=func_attrs["name"],
input="(void *) memory_pool->RequestHalfTensorByIdx(0)",
gamma="(void *) memory_pool->RequestHalfTensorByIdx(2)",
beta="(void *) memory_pool->RequestHalfTensorByIdx(3)",
output="(void *) memory_pool->RequestHalfTensorByIdx(1)",
input_dim_names=input_dim_names,
indent=indent,
)
code = PROFILER_TEMPLATE.render(
op_func=op_func,
structs_def=structs_def,
args_parse=args_parse,
tensor_decl=tensor_decl,
func_call=func_call,
)
prefix = os.path.join(workdir, "profiler", op_type)
if not os.path.exists(prefix):
os.makedirs(prefix)
src_path = os.path.join(prefix, op_name + ".cpp")
obj_path = os.path.join(prefix, op_name)
if os.path.exists(obj_path):
continue
with open(src_path, "w") as fo:
fo.write(code)
file_pairs.append((src_path, obj_path))
return file_pairs
# no longer used by layernorm
def gen_function(
func_attrs: Dict[str, Any],
exec_template: jinja2.Template,
extra_header_template: jinja2.Template,
get_func_signature: Any,
) -> str:
"""Generate function body.
Parameters
----------
func_attrs : Dict
Operation attributes.
exec_template : jinja2.Template
Execution block template.
extra_header_template : jinja2.Template
Extra header template.
Returns
-------
str
The rendered template of generated function body.
"""
shapes = func_attrs["inputs"][0]._attrs["shape"]
rank = len(shapes)
exec_path = func_attrs["exec_path"]
op_instance = func_attrs["op_instance"]
inst_def_flag = set()
instances = {}
instance_decl = ""
for exec_item in exec_path.values():
fname = "f" + sha1(exec_item.exec_cond.encode()).hexdigest()
algo = exec_item.algo
if algo not in inst_def_flag:
config = emit_instance(op_instance[algo])
inst_def_flag.add(algo)
else:
config = ""
inst = INSTANCE_TEMPLATE.render(
config=config, name=fname, config_name=extract_config_name(config)
)
instances[exec_item.exec_cond] = inst
instance_decl += inst
exec_cond_template = func_attrs["exec_cond_template"]
exec_paths = ""
for key, _ in instances.items():
fname = "f" + sha1(key.encode()).hexdigest()
program = exec_template.render(
instance=fname, dtype="void", reduce_dims=rank - 1, rank=rank
)
cond_vars = re.findall(r"\S+(?= >=)", key)
cond_vars += re.findall(r"\S+(?= ==)", key)
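# e.g. a (hypothetical) exec_cond "M >= 1 && N >= 1" is rewritten below into
# "*in_0>= 1 && *in_1>= 1", so the condition can be evaluated against the
# generated function's dereferenced input-dim arguments.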
cond = key
for i, var in enumerate(cond_vars):
cond = cond.replace(var + " ", "*in_" + str(i))
exec_inst = exec_cond_template.render(indent=" ", cond=cond, program=program)
exec_paths += exec_inst
return FUNC_TEMPLATE.render(
instances_decl=instance_decl,
func_signature=get_func_signature(func_attrs),
exec_paths=exec_paths,
extra_headers=extra_header_template.render(),
)
def gen_function_call(func_attrs, indent=" "):
"""Generates function call.
Parameters
----------
func_attrs : Dict
Stores the operation attributes.
indent : str, optional
Indent for codegen, target dependent e.g. C++, python, etc., by default " ".
Returns
-------
str
The rendered template of generated function call.
"""
input_name = FUNC_CALL_PARAM_TEMPLATE.render(
name=func_attrs["inputs"][0]._attrs["name"]
)
output_name = FUNC_CALL_PARAM_TEMPLATE.render(
name=func_attrs["outputs"][0]._attrs["name"]
)
shapes = func_attrs["inputs"][0]._attrs["shape"]
input_dim_names = [shape._attrs["name"] for shape in shapes]
return FUNC_CALL_TEMPLATE.render(
func_name=func_attrs["name"],
input=input_name,
output=output_name,
input_dim_names=input_dim_names,
indent=indent,
)
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import jinja2
EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
ck::index_t M0 = M / G1 / G2;
ck::index_t M1 = G1;
ck::index_t M2 = G2;
ck::index_t N0 = G3;
ck::index_t N1 = N / G3;
// GEMM shape
//ck::index_t M = M0 * M1 * M2;
//ck::index_t N = N0 * N1;
//ck::index_t K = 128;
//ck::index_t stride_A = K;
//ck::index_t stride_B = K;
// E = [M0, N0, M1, N1, M2]
/* 0, 3, 1, 4, 2
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
*/
// E = [M2, M0, N0, M1, N1] 2, 0, 3, 1, 4
ck::index_t stride_E_M0 = N0* M1* N1;
ck::index_t stride_E_M1 = N1;
ck::index_t stride_E_M2 = M0* N0* M1* N1;
ck::index_t stride_E_N0 = M1 * N1;
ck::index_t stride_E_N1 = 1;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
"""
)
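# The E strides in the template above encode a permuted dense layout of the
# output tensor. A minimal sketch of how such strides fall out of a permutation
# (illustration only; the helper below is not used by the codegen):
def _permuted_strides(dims, perm):
    """Element strides of a dense tensor whose physical axis order is
    dims[perm[0]], dims[perm[1]], ..., returned in logical axis order."""
    strides = [0] * len(dims)
    running = 1
    for axis in reversed(perm):
        strides[axis] = running
        running *= dims[axis]
    return strides
# E laid out as [M2, M0, N0, M1, N1], i.e. permutation (2, 0, 3, 1, 4) of the
# logical dims [M0, M1, M2, N0, N1], gives strides
#   [N0 * M1 * N1, N1, M0 * N0 * M1 * N1, M1 * N1, 1]
# matching stride_E_M0 .. stride_E_N1 above.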
EXTRA_SHAPE_TEMPLATE_M2N3 = jinja2.Template(
"""
const int64_t G1 = p_dim0; // G1
const int64_t G2 = p_dim1; // G2
const int64_t G3 = p_dim2; // G3
ck::index_t M0 = M / G1;
ck::index_t M1 = G1;
ck::index_t N0 = G2;
ck::index_t N1 = G3;
ck::index_t N2 = N / G2 / G3;
ck::index_t K0 = K;
ck::index_t G = 1;
// A[G, M0, M1, M2, K0]
std::vector<ck::index_t> a_ms_ks_lengths{G, M0, M1, K0};
std::vector<ck::index_t> a_ms_ks_strides{M0*M1*K0, M1 * K0, K0, 1};
// B[G, N0, N1, K0]
std::vector<ck::index_t> b_ns_ks_lengths{G, N0, N1, N2, K0};
std::vector<ck::index_t> b_ns_ks_strides{N0*N1*N2*K0, N1 * N2 * K0, N2 * K0, K0, 1};
// D[G, N0, M0, N1, M1, N2]
std::vector<ck::index_t> d_ms_ns_lengths{G, M0, M1, N0, N1, N2};
std::vector<ck::index_t> d_ms_ns_strides{N0 * N1 * N2, 0, 0, N1 * N2, N2, 1};
// E[G, N0, M0, N1, M1, N2] 2, 0, 3, 1, 4
std::vector<ck::index_t> e_ms_ns_lengths{G, M0, M1, N0, N1, N2};
std::vector<ck::index_t> e_ms_ns_strides{M0* M1* N0* N1* N2,
N1 * M1 * N2,
N2,
M0 * N1 * M1 * N2,
M1 * N2,
1};
"""
)
EXTRA_SHAPE_TEMPLATE_M3N2 = jinja2.Template(
"""
const int64_t G1 = p_dim0; // G1
const int64_t G2 = p_dim1; // G2
const int64_t G3 = p_dim2; // G3
ck::index_t M0 = M / G1 / G2;
ck::index_t M1 = G1;
ck::index_t M2 = G2;
ck::index_t N0 = G3;
ck::index_t N1 = N / G3;
ck::index_t K0 = K;
ck::index_t G = 1;
// A[M0, M1, M2, K0]
std::vector<ck::index_t> a_ms_ks_lengths{G, M0, M1, M2, K0};
std::vector<ck::index_t> a_ms_ks_strides{M0 * M1 * M2 * K0, M1 * M2 * K0, M2 * K0, K0, 1};
// B[N0, N1, K0]
std::vector<ck::index_t> b_ns_ks_lengths{G, N0, N1, K0};
std::vector<ck::index_t> b_ns_ks_strides{N0 * N1 * K0, N1 * K0, K0, 1};
// D[M0, N0, M1, N1, M2]
std::vector<ck::index_t> d_ms_ns_lengths{G, M0, M1, M2, N0, N1};
std::vector<ck::index_t> d_ms_ns_strides{N0*N1, 0, 0, 0, N1, 1};
// E[M0, N0, M1, N1, M2]
std::vector<ck::index_t> e_ms_ns_lengths{G, M0, M1, M2, N0, N1};
std::vector<ck::index_t> e_ms_ns_strides{M0 * M1* M2 * N1* N0, N0* M1* N1, N1, M0* N0* M1* N1, M1 * N1, 1};
"""
)
\ No newline at end of file
@@ -5,12 +5,12 @@ CXXFLAGS = -std=c++17
gemm: ex.o host_tensor.o device_memory.o
hipcc $(CXXFLAGS) $(CFLAGS) ex.o host_tensor.o device_memory.o -o gemm
device_memory.o: ../../../../library/src/utility/device_memory.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/device_memory.cpp
device_memory.o: ../../library/src/utility/device_memory.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../library/src/utility/device_memory.cpp
host_tensor.o: ../../../../library/src/utility/host_tensor.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/host_tensor.cpp
host_tensor.o: ../../library/src/utility/host_tensor.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../library/src/utility/host_tensor.cpp
ex.o:
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c ex.cpp
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c ex.cpp
\ No newline at end of file
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::GemmSpecialization::Default,
256,
64,
128,
8,
2,
2,
4,
1,
S<8, 2>,
S<8, 2>,
S<1, 1, 2, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 2, 1>,
S<0, 3, 1, 2>,
S<1, 1, 2, 2>,
S<1, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm_256_64_128_8_2(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ck::half_t> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ck::tensor_layout::gemm::ColumnMajor{}));
Tensor<ck::half_t> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ck::tensor_layout::gemm::RowMajor{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(b_k_n);
}
Tensor<ck::half_t> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ck::half_t> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ck::half_t) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(ck::half_t) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(ck::half_t) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto b_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto c_element_op = ck::tensor_operation::element_wise::PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ck::half_t*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ck::half_t) * M * K + sizeof(ck::half_t) * K * N + sizeof(ck::half_t) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool run_gemm_example_256_64_128_8_2(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_256_64_128_8_2(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_gemm_example_256_64_128_8_2(argc, argv); }
@@ -6,6 +6,8 @@ import operator
import collections
import subprocess
import re
import gemm_op
from gemm_op import *
def SubstituteTemplate(template, values):
text = template
@@ -30,14 +32,14 @@ CXXFLAGS = -std=c++17
gemm: ex.o host_tensor.o device_memory.o
hipcc $(CXXFLAGS) $(CFLAGS) ex.o host_tensor.o device_memory.o -o gemm
device_memory.o: ../../../../library/src/utility/device_memory.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/device_memory.cpp
device_memory.o: ../../library/src/utility/device_memory.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../library/src/utility/device_memory.cpp
host_tensor.o: ../../../../library/src/utility/host_tensor.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/host_tensor.cpp
host_tensor.o: ../../library/src/utility/host_tensor.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../library/src/utility/host_tensor.cpp
ex.o:
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c ex.cpp
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c ex.cpp
"""
self.gemm_devop_template = """
#pragma once
@@ -105,7 +107,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
bool run_gemm_${name}(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
@@ -213,66 +215,69 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
return true;
}
bool run_gemm_example(int argc, char* argv[])
bool run_gemm_example_${name}(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_${name}(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example_${name}(argc, argv); }
"""
def emit(self):
def emit(self,operation):
values = {
'type_a' : 'ck::half_t',
'type_b' : 'ck::half_t',
'type_c' : 'ck::half_t',
'name' : (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)+ "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.k1)),
'type_a' : operation.A.element,
'type_b' : operation.B.element,
'type_c' : operation.C.element,
'type_acc' : 'float',
'layout_a' : 'ck::tensor_layout::gemm::ColumnMajor',
'layout_b' : 'ck::tensor_layout::gemm::RowMajor',
'layout_c' : 'ck::tensor_layout::gemm::RowMajor',
'elementwise_op_a' : 'ck::tensor_operation::element_wise::PassThrough',
'elementwise_op_b' : 'ck::tensor_operation::element_wise::PassThrough',
'elementwise_op_c' : 'ck::tensor_operation::element_wise::PassThrough',
'Gemm_spec' : 'ck::tensor_operation::device::GemmSpecialization::Default',
'block_size' : '256',
'mperblock' : '128',
'nperblock' : '128',
'k0perblock' : '16',
'k1' : '2',
'm1perthread' : '4',
'n1perthread' : '4',
'kperthread' : '1',
'm1n1_thcluster_m1xs' : 'S<8, 2>',
'm1n1_thcluster_n1xs' : 'S<8, 2>',
'ABT_thread_slice_lengths_K0_M0_M1_K1' : 'S<2, 1, 4, 2>',
'ABT_thread_cluster_lengths_K0_M0_M1_K1' : 'S<8, 1, 32, 1>',
'ABT_thread_cluster_arrange_order' : 'S<0, 3, 1, 2>',
'ABT_src_access_order' : 'S<0, 3, 1, 2>',
'ABT_src_vec_tensor_lengths_K0_M0_M1_K1' : 'S<1, 1, 4, 1>',
'ABT_src_vec_tensor_cont_dim_order' : 'S<0, 3, 1, 2>',
'ABT_dst_vec_tensor_lengths_K0_M0_M1_K1' : 'S<1, 1, 4, 2>',
'BBT_thread_slice_lengths_K0_N0_N1_K1' : 'S<2, 1, 4, 2>',
'BBT_thread_cluster_lengths_K0_N0_N1_K1' : 'S<8, 1, 32, 1>',
'BBT_thread_cluster_arrange_order' : 'S<0, 3, 1, 2>',
'BBT_src_access_order' : 'S<0, 3, 1, 2>',
'BBT_src_vec_tensor_lengths_K0_N0_N1_K1' : 'S<1, 1, 4, 1>',
'BBT_src_vec_tensor_cont_dim_order' : 'S<0, 3, 1, 2>',
'BBT_dst_vec_tensor_lengths_K0_N0_N1_K1': 'S<1, 1, 4, 2>',
'CTT_src_dst_access_order' : 'S<0, 1, 2, 3, 4, 5>',
'CTT_src_dst_vec_dim' : '5',
'CTT_dst_scalar_per_vector' : '4'
'layout_a' : operation.A.layout,
'layout_b' : operation.B.layout,
'layout_c' : operation.C.layout,
'elementwise_op_a' : operation.a_elem_op,
'elementwise_op_b' : operation.b_elem_op,
'elementwise_op_c' : operation.epilogue_functor,
'Gemm_spec' : operation.gemm_specialization,
'block_size' : str(operation.tile_desc.block_size),
'mperblock' : str(operation.tile_desc.m_per_block),
'nperblock' : str(operation.tile_desc.n_per_block),
'k0perblock' : str(operation.tile_desc.k_per_block),
'k1' : str(operation.tile_desc.k1),
'm1perthread' : str(operation.tile_desc.m_per_thread),
'n1perthread' : str(operation.tile_desc.n_per_thread),
'kperthread' : str(operation.tile_desc.k_per_thread),
'm1n1_thcluster_m1xs' : operation.tile_desc.m1n1_thcluster_m1xs,
'm1n1_thcluster_n1xs' : operation.tile_desc.m1n1_thcluster_n1xs,
'ABT_thread_slice_lengths_K0_M0_M1_K1' : operation.a_block_transfer.thread_slice_length,
'ABT_thread_cluster_lengths_K0_M0_M1_K1' : operation.a_block_transfer.thread_cluster_length,
'ABT_thread_cluster_arrange_order' : operation.a_block_transfer.thread_cluster_arrange_order,
'ABT_src_access_order' : operation.a_block_transfer.src_access_order,
'ABT_src_vec_tensor_lengths_K0_M0_M1_K1' : operation.a_block_transfer.src_vec_tensor_lengths,
'ABT_src_vec_tensor_cont_dim_order' : operation.a_block_transfer.src_vec_tensor_cont_dim_order,
'ABT_dst_vec_tensor_lengths_K0_M0_M1_K1' : operation.a_block_transfer.dst_vec_tensor_lengths,
'BBT_thread_slice_lengths_K0_N0_N1_K1' : operation.b_block_transfer.thread_slice_length,
'BBT_thread_cluster_lengths_K0_N0_N1_K1' : operation.b_block_transfer.thread_cluster_length,
'BBT_thread_cluster_arrange_order' : operation.b_block_transfer.thread_cluster_arrange_order,
'BBT_src_access_order' : operation.b_block_transfer.src_access_order,
'BBT_src_vec_tensor_lengths_K0_N0_N1_K1' : operation.b_block_transfer.src_vec_tensor_lengths,
'BBT_src_vec_tensor_cont_dim_order' : operation.b_block_transfer.src_vec_tensor_cont_dim_order,
'BBT_dst_vec_tensor_lengths_K0_N0_N1_K1': operation.b_block_transfer.dst_vec_tensor_lengths,
'CTT_src_dst_access_order' : operation.c_block_transfer.src_dst_access_order,
'CTT_src_dst_vec_dim' : str(operation.c_block_transfer.src_dst_vec_dim),
'CTT_dst_scalar_per_vector' : str(operation.c_block_transfer.dst_scalar_per_vector),
}
template = self.gemm_devop_template
name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
+ "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.k1))
template = self.gemm_devop_template
cf = open("ex.cpp", 'w')
print(SubstituteTemplate(template, values))
cf.write(SubstituteTemplate(template, values))
cf.close()
m_template = self.make_template
cf = open("Makefile", 'w')
print(SubstituteTemplate(m_template, values))
cf.write(SubstituteTemplate(m_template, values))
cf.close()
@@ -288,6 +293,27 @@ int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
)
out, err = proc.communicate()
# defining an operation's parameters as input
A = TensorDesc(DataType.f16, Layout.ColumnMajor)
B = TensorDesc(DataType.f16, Layout.RowMajor)
C = TensorDesc(DataType.f16, Layout.RowMajor)
gemm = gemm_op.GemmOperation(
A=A,
B=B,
C=C,
a_elem_op=TensorOperation.PassThrough,
b_elem_op=TensorOperation.PassThrough,
epilogue_functor=TensorOperation.PassThrough,
gemm_specialization=GemmType.GemmDefault,
tile_desc=TileDesc(256, 64, 128, 8, 2, 2, 4, 1, "S<8, 2>", "S<8, 2>"),
a_block_transfer=BlockTransferDesc(
"S<1, 1, 2, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 2, 1>", "S<0, 3, 1, 2>", "S<1, 1, 2, 2>"
),
b_block_transfer=BlockTransferDesc(
"S<1, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"
),
c_block_transfer=CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
)
a = EmitGemmInstance()
a.emit()
a.emit(gemm)
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::GemmSpecialization::Default,
128,
128,
32,
8,
2,
4,
2,
1,
S<8, 2>,
S<4, 2>,
S<1, 1, 8, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<1, 1, 2, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 2, 1>,
S<0, 3, 1, 2>,
S<1, 1, 2, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
2>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm_128_128_32_8_2(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ck::half_t> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ck::tensor_layout::gemm::ColumnMajor{}));
Tensor<ck::half_t> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ck::tensor_layout::gemm::RowMajor{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(b_k_n);
}
Tensor<ck::half_t> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ck::half_t> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ck::half_t) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(ck::half_t) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(ck::half_t) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto b_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto c_element_op = ck::tensor_operation::element_wise::PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ck::half_t*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ck::half_t) * M * K + sizeof(ck::half_t) * K * N + sizeof(ck::half_t) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool __attribute__((visibility("default"))) run_gemm_128_128_32_8_2(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_128_128_32_8_2(problem_size, config);
}
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
    ADataType,
    BDataType,
    CDataType,
    AccDataType,
    ALayout,
    BLayout,
    CLayout,
    AElementOp,
    BElementOp,
    CElementOp,
    GemmDefault,
128,
128,
64,
8,
2,
4,
4,
1,
S<8, 2>,
S<4, 2>,
S<1, 1, 8, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<1, 1, 4, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm_128_128_64_8_2(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
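    // Build a host tensor descriptor whose strides match the requested layout:
    // row-major keeps rows contiguous ({stride, 1}), column-major keeps columns
    // contiguous ({1, stride}).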
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ck::half_t> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ck::tensor_layout::gemm::ColumnMajor{}));
Tensor<ck::half_t> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ck::tensor_layout::gemm::RowMajor{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(b_k_n);
}
Tensor<ck::half_t> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ck::half_t> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ck::half_t) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(ck::half_t) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(ck::half_t) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto b_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto c_element_op = ck::tensor_operation::element_wise::PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ck::half_t*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ck::half_t) * M * K + sizeof(ck::half_t) * K * N + sizeof(ck::half_t) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool __attribute__((visibility("default"))) run_gemm_128_128_64_8_2(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_128_128_64_8_2(problem_size, config);
}
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
    ADataType,
    BDataType,
    CDataType,
    AccDataType,
    ALayout,
    BLayout,
    CLayout,
    AElementOp,
    BElementOp,
    CElementOp,
    GemmDefault,
128,
32,
128,
8,
2,
2,
4,
1,
S<4, 2>,
S<8, 2>,
S<1, 1, 2, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 2, 1>,
S<0, 3, 1, 2>,
S<1, 1, 2, 2>,
S<1, 1, 8, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm_128_32_128_8_2(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ck::half_t> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ck::tensor_layout::gemm::ColumnMajor{}));
Tensor<ck::half_t> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ck::tensor_layout::gemm::RowMajor{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(b_k_n);
}
Tensor<ck::half_t> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ck::half_t> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ck::half_t) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(ck::half_t) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(ck::half_t) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto b_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto c_element_op = ck::tensor_operation::element_wise::PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ck::half_t*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ck::half_t) * M * K + sizeof(ck::half_t) * K * N + sizeof(ck::half_t) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool __attribute__((visibility("default"))) run_gemm_128_32_128_8_2(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_128_32_128_8_2(problem_size, config);
}
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
    ADataType,
    BDataType,
    CDataType,
    AccDataType,
    ALayout,
    BLayout,
    CLayout,
    AElementOp,
    BElementOp,
    CElementOp,
    GemmDefault,
128,
32,
64,
8,
2,
2,
2,
1,
S<4, 2>,
S<8, 2>,
S<1, 1, 2, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 2, 1>,
S<0, 3, 1, 2>,
S<1, 1, 2, 2>,
S<1, 1, 4, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm_128_32_64_8_2(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ck::half_t> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ck::tensor_layout::gemm::ColumnMajor{}));
Tensor<ck::half_t> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ck::tensor_layout::gemm::RowMajor{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(b_k_n);
}
Tensor<ck::half_t> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ck::half_t> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ck::half_t) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(ck::half_t) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(ck::half_t) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto b_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto c_element_op = ck::tensor_operation::element_wise::PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ck::half_t*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ck::half_t) * M * K + sizeof(ck::half_t) * K * N + sizeof(ck::half_t) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool __attribute__((visibility("default"))) run_gemm_128_32_64_8_2(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_128_32_64_8_2(problem_size, config);
}
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
    ADataType,
    BDataType,
    CDataType,
    AccDataType,
    ALayout,
    BLayout,
    CLayout,
    AElementOp,
    BElementOp,
    CElementOp,
    GemmDefault,
128,
64,
128,
8,
2,
4,
4,
1,
S<4, 2>,
S<8, 2>,
S<1, 1, 4, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<1, 1, 8, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm_128_64_128_8_2(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ck::half_t> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ck::tensor_layout::gemm::ColumnMajor{}));
Tensor<ck::half_t> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ck::tensor_layout::gemm::RowMajor{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(b_k_n);
}
Tensor<ck::half_t> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ck::half_t> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ck::half_t) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(ck::half_t) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(ck::half_t) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto b_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto c_element_op = ck::tensor_operation::element_wise::PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ck::half_t*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ck::half_t) * M * K + sizeof(ck::half_t) * K * N + sizeof(ck::half_t) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool __attribute__((visibility("default"))) run_gemm_128_64_128_8_2(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_128_64_128_8_2(problem_size, config);
}
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
    ADataType,
    BDataType,
    CDataType,
    AccDataType,
    ALayout,
    BLayout,
    CLayout,
    AElementOp,
    BElementOp,
    CElementOp,
    GemmDefault,
128,
64,
32,
8,
2,
2,
2,
1,
S<8, 2>,
S<4, 2>,
S<1, 1, 4, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<1, 1, 2, 2>,
S<8, 1, 16, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 2, 1>,
S<0, 3, 1, 2>,
S<1, 1, 2, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
2>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm_128_64_32_8_2(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ck::half_t> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ck::tensor_layout::gemm::ColumnMajor{}));
Tensor<ck::half_t> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ck::tensor_layout::gemm::RowMajor{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ck::half_t>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(b_k_n);
}
Tensor<ck::half_t> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ck::half_t> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ck::half_t) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(ck::half_t) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(ck::half_t) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto b_element_op = ck::tensor_operation::element_wise::PassThrough{};
auto c_element_op = ck::tensor_operation::element_wise::PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ck::half_t*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ck::half_t*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ck::half_t) * M * K + sizeof(ck::half_t) * K * N + sizeof(ck::half_t) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool __attribute__((visibility("default"))) run_gemm_128_64_32_8_2(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_128_64_32_8_2(problem_size, config);
}
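// Illustrative driver (not part of the generated instance files above): a minimal
// sketch of how the exported run_gemm_* entry points could be exercised from a
// separate translation unit linked against the compiled shared library. The file
// and library names (driver.cpp, libgemm_instances.so) are hypothetical; substitute
// the actual build outputs.
//
// driver.cpp
// build, e.g.: g++ driver.cpp -L<build_dir> -lgemm_instances -o gemm_driver
bool run_gemm_128_128_32_8_2(int argc, char* argv[]);
bool run_gemm_128_128_64_8_2(int argc, char* argv[]);

int main(int argc, char* argv[])
{
    // All instances share the same parse_cmd_args command-line handling,
    // so the raw argc/argv can be forwarded to each of them unchanged.
    bool ok = run_gemm_128_128_32_8_2(argc, argv);
    ok      = run_gemm_128_128_64_8_2(argc, argv) && ok;
    return ok ? 0 : 1;
}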