"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "5753c5bbaf762be1fae091c16d5f08016d32efd1"
Unverified commit 9dddb36d, authored by Phuong Nguyen, committed by GitHub
Browse files

[TE/JAX] Custom call with FFI - lowering all attributes with bind all (#1289)



* lowering a dict of attrs

* improve err message with line and func info

* implement a product() for ffi dimensions

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c036765b
...@@ -401,14 +401,11 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -401,14 +401,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias_heads=bias_heads, bias_heads=bias_heads,
head_dim=head_dim, head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq, max_segments_per_seq=config.max_segments_per_seq,
wkspace_size=wkspace_aval.size,
scaling_factor=float(config.scaling_factor), scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability), dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type), bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type), mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout), qkv_layout=int(config.qkv_layout),
dtype=int(jax_dtype_to_te_dtype(q_aval.dtype)),
wkspace_dtype=int(jax_dtype_to_te_dtype(wkspace_aval.dtype)),
is_training=config.is_training, is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0], window_size_left=config.window_size[0],
......
...@@ -110,7 +110,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type outp ...@@ -110,7 +110,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type outp
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>()); auto m = product(input_dims, 0, input_dims.size() - 2);
auto n = input_dims.back(); auto n = input_dims.back();
auto act_len = input_dims.end()[-2]; auto act_len = input_dims.end()[-2];
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
...@@ -175,7 +175,7 @@ Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a ...@@ -175,7 +175,7 @@ Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a
} }
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>()); auto m = product(input_dims, 0, input_dims.size() - 2);
auto n = input_dims.back(); auto n = input_dims.back();
auto act_len = input_dims.end()[-2]; auto act_len = input_dims.end()[-2];
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
...@@ -264,8 +264,7 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act ...@@ -264,8 +264,7 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto act_input_dims = act_input_buf.dimensions(); auto act_input_dims = act_input_buf.dimensions();
auto m = auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, 1, std::multiplies<>());
auto n = act_input_dims.back(); auto n = act_input_dims.back();
auto act_len = act_input_dims.end()[-2]; auto act_len = act_input_dims.end()[-2];
......
...@@ -329,36 +329,55 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -329,36 +329,55 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right); descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
} }
Error_Type FusedAttnForwardFFI( Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, Buffer_Type seed_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf, Buffer_Type seed_buf, Result_Type output_buf,
Result_Type workspace_buf, int64_t input_batch_, int64_t bias_batch_, int64_t q_max_seqlen_, Result_Type softmax_aux_buf, Result_Type rng_state_buf,
int64_t kv_max_seqlen_, int64_t attn_heads_, int64_t num_gqa_groups_, int64_t bias_heads_, Result_Type workspace_buf, Dictionary attrs) {
int64_t head_dim_, int64_t max_segments_per_seq_, int64_t wkspace_size_, double scaling_factor_, /* Descriptor data type conversion */
double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_, size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch");
int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic, size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch");
int64_t window_size_left, int64_t window_size_right) { size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen");
NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_); size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen");
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads");
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups");
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads");
size_t head_dim = get_attr_value<int64_t>(attrs, "head_dim");
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq");
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left");
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right");
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor");
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability");
NVTE_Bias_Type bias_type =
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type"));
NVTE_Mask_Type mask_type =
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type"));
NVTE_QKV_Layout qkv_layout =
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout"));
bool is_training = get_attr_value<bool>(attrs, "is_training");
bool deterministic = get_attr_value<bool>(attrs, "deterministic");
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
size_t wkspace_size = product(workspace_buf->dimensions());
DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type());
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
FusedAttnForwardImpl( FusedAttnForwardImpl(
stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(), bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(), is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(),
output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
workspace_buf->untyped_data(), static_cast<size_t>(input_batch_), workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
static_cast<size_t>(bias_batch_), static_cast<size_t>(q_max_seqlen_), attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
static_cast<size_t>(kv_max_seqlen_), static_cast<size_t>(attn_heads_), scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
static_cast<size_t>(num_gqa_groups_), static_cast<size_t>(bias_heads_), is_training, deterministic, window_size_left, window_size_right);
static_cast<size_t>(head_dim_), static_cast<size_t>(max_segments_per_seq_),
static_cast<size_t>(wkspace_size_), static_cast<float>(scaling_factor_),
static_cast<float>(dropout_probability_), static_cast<NVTE_Bias_Type>(bias_type_),
static_cast<NVTE_Mask_Type>(mask_type_), static_cast<NVTE_QKV_Layout>(qkv_layout_),
static_cast<DType>(dtype_), static_cast<DType>(wkspace_dtype_), is_training, deterministic,
window_size_left, window_size_right);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
...@@ -379,27 +398,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, ...@@ -379,27 +398,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
.Ret<Buffer_Type>() // softmax_aux .Ret<Buffer_Type>() // softmax_aux
.Ret<Buffer_Type>() // rng_state .Ret<Buffer_Type>() // rng_state
.Ret<Buffer_Type>() // workspace .Ret<Buffer_Type>() // workspace
.Attr<int64_t>("input_batch") .Attrs(),
.Attr<int64_t>("bias_batch")
.Attr<int64_t>("q_max_seqlen")
.Attr<int64_t>("kv_max_seqlen")
.Attr<int64_t>("attn_heads")
.Attr<int64_t>("num_gqa_groups")
.Attr<int64_t>("bias_heads")
.Attr<int64_t>("head_dim")
.Attr<int64_t>("max_segments_per_seq")
.Attr<int64_t>("wkspace_size")
.Attr<double>("scaling_factor")
.Attr<double>("dropout_probability")
.Attr<int64_t>("bias_type")
.Attr<int64_t>("mask_type")
.Attr<int64_t>("qkv_layout")
.Attr<int64_t>("dtype")
.Attr<int64_t>("wkspace_dtype")
.Attr<bool>("is_training")
.Attr<bool>("deterministic")
.Attr<int64_t>("window_size_left")
.Attr<int64_t>("window_size_right"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
...@@ -608,7 +607,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -608,7 +607,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dqkv = buffers[12]; auto dqkv = buffers[12];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
if (is_ragged) { if (is_ragged) {
cudaMemsetAsync(dqkv, 0, product(qkv_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dqkv, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream);
} }
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
...@@ -630,8 +629,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -630,8 +629,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dkv = buffers[13]; auto dkv = buffers[13];
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
if (is_ragged) { if (is_ragged) {
cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dkv, 0, product(kv_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dkv, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream);
} }
nvte_fused_attn_bwd_kvpacked( nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
...@@ -659,9 +658,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -659,9 +658,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dv = buffers[14]; auto dv = buffers[14];
auto dv_tensor = TensorWrapper(dv, v_shape, dtype); auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (is_ragged) { if (is_ragged) {
cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, product(k_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dk, 0, transformer_engine::product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, product(v_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dv, 0, transformer_engine::product(v_shape) * typeToSize(dtype), stream);
} }
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(), doutput_tensor.data(),
......
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
#include <iostream> #include <iostream>
#include "common/util/logging.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <numeric> #include <numeric>
#include "common/util/logging.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -17,10 +19,63 @@ using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>; ...@@ -17,10 +19,63 @@ using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
using Error_Type = xla::ffi::Error; using Error_Type = xla::ffi::Error;
using FFI = xla::ffi::Ffi; using FFI = xla::ffi::Ffi;
using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>; using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
using Dictionary = xla::ffi::Dictionary;
constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type); DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type);
Error_Type ffi_with_cuda_error_check(); Error_Type ffi_with_cuda_error_check();
// std::source_location is not available in C++17, so we implement a minimal
// equivalent ourselves on top of compiler builtins. GCC, Clang, and MSVC
// (>= VS2019 16.6, _MSC_VER 1926) evaluate __builtin_FILE/LINE/FUNCTION at
// the *call site* when used as default arguments, which is what lets
// source_location::current() report the caller's location.
#if defined(__GNUC__) || defined(__clang__) || (defined(_MSC_VER) && _MSC_VER >= 1926)
#define CURRENT_FILE __builtin_FILE()
#define CURRENT_LINE __builtin_LINE()
#define CURRENT_FUNCTION __builtin_FUNCTION()
#else
// Fallback: __FILE__/__LINE__ expand where the default argument is *declared*
// (i.e. in this header), so on such compilers reported locations point here
// rather than at the caller. Error messages remain correct, just less precise.
#define CURRENT_FILE __FILE__
#define CURRENT_LINE __LINE__
#define CURRENT_FUNCTION __func__
#endif

// Minimal stand-in for C++20 std::source_location: captures the file, line,
// and enclosing function of the call site of current().
class source_location {
 public:
  // Do not pass arguments explicitly; the defaults capture the caller's location.
  static source_location current(const char* file = CURRENT_FILE, int line = CURRENT_LINE,
                                 const char* function = CURRENT_FUNCTION) noexcept {
    return source_location(file, line, function);
  }

  constexpr const char* file_name() const noexcept { return file_; }
  constexpr int line() const noexcept { return line_; }
  constexpr const char* function_name() const noexcept { return function_; }

 private:
  constexpr source_location(const char* file, int line, const char* function) noexcept
      : file_(file), line_(line), function_(function) {}

  const char* file_;
  int line_;
  const char* function_;
};
// Fetches attribute `attr_name` from the XLA FFI attribute dictionary as type T.
// On a missing attribute or a C++/Python type mismatch, aborts via NVTE_ERROR
// with a message that includes the call site captured through `loc`.
// `loc` defaults to the caller's location; do not pass it explicitly.
// `attr_name` is taken by const reference to avoid a std::string copy per fetch.
template <typename T>
T get_attr_value(Dictionary& attrs, const std::string& attr_name,
                 const source_location& loc = source_location::current()) {
  auto attr = attrs.get<T>(attr_name);
  if (attr.has_error()) {
    NVTE_ERROR("Failure in getting attribute value of '", attr_name, "'\n",
               "Called from: ", loc.file_name(), ":", loc.line(), "\n",
               "In function: ", loc.function_name(), "\n",
               "Please ensure the attribute name and datatype match between C++ and Python APIs.");
  }
  return attr.value();
}
// Multiplies together the elements data[start_idx, end_idx) as a size_t.
// When end_idx is omitted, the product runs through the end of the span.
// An explicitly empty range (start_idx == end_idx, e.g. product(dims, 0, 0))
// yields 1, the empty product. NOTE: the previous sentinel for "whole span"
// was end_idx == 0, which silently turned intentionally empty ranges — hit by
// callers such as product(dims, 0, dims.size() - 2) on 2-D inputs and
// product(dims, 0, transpose_axis) with transpose_axis == 0 — into a product
// over ALL elements. size_t(-1) cannot collide with a real index.
inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_idx = 0,
                      size_t end_idx = size_t(-1)) {
  // Clamp the default sentinel (and any overlong range) to the span size.
  if (end_idx > data.size()) end_idx = data.size();
  if (start_idx > end_idx) start_idx = end_idx;
  return std::accumulate(data.begin() + start_idx, data.begin() + end_idx, size_t(1),
                         std::multiplies<size_t>());
}
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -264,19 +264,13 @@ Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer ...@@ -264,19 +264,13 @@ Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer
NVTE_CHECK(amax_out == amax, NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive"); "amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
auto x_dims = x_buf.dimensions(); auto x_size = product(x_buf.dimensions());
auto gamma_dims = gamma_buf.dimensions(); auto gamma_size = product(gamma_buf.dimensions());
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto hidden_size = gamma_size; auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size; auto batch_size = x_size / gamma_size;
auto wkspace_dims = wkspace_buf->dimensions(); auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_dims = barrier_buf->dimensions(); auto barrier_size = product(barrier_buf->dimensions());
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
float eps = static_cast<float>(eps_); float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_); int sm_margin = static_cast<int>(sm_margin_);
...@@ -408,19 +402,13 @@ Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_ ...@@ -408,19 +402,13 @@ Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_
auto *dgamma_part = dgamma_part_buf->untyped_data(); auto *dgamma_part = dgamma_part_buf->untyped_data();
auto *dbeta_part = dbeta_part_buf->untyped_data(); auto *dbeta_part = dbeta_part_buf->untyped_data();
auto x_dims = x_buf.dimensions(); auto x_size = product(x_buf.dimensions());
auto gamma_dims = gamma_buf.dimensions(); auto gamma_size = product(gamma_buf.dimensions());
auto x_size = std::accumulate(x_dims.begin(), x_dims.end(), 1, std::multiplies<>());
auto gamma_size = std::accumulate(gamma_dims.begin(), gamma_dims.end(), 1, std::multiplies<>());
auto hidden_size = gamma_size; auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size; auto batch_size = x_size / gamma_size;
auto wkspace_dims = wkspace_buf->dimensions(); auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_dims = barrier_buf->dimensions(); auto barrier_size = product(barrier_buf->dimensions());
auto wkspace_size =
std::accumulate(wkspace_dims.begin(), wkspace_dims.end(), 1, std::multiplies<>());
auto barrier_size =
std::accumulate(barrier_dims.begin(), barrier_dims.end(), 1, std::multiplies<>());
auto dgamma_part_dims = dgamma_part_buf->dimensions(); auto dgamma_part_dims = dgamma_part_buf->dimensions();
auto dbeta_part_dims = dbeta_part_buf->dimensions(); auto dbeta_part_dims = dbeta_part_buf->dimensions();
......
...@@ -46,10 +46,9 @@ Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type ...@@ -46,10 +46,9 @@ Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size(); if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1, auto m = product(input_dims, 0, transpose_axis);
std::multiplies<>()); auto n = product(input_dims, transpose_axis, input_dims.size());
auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1,
std::multiplies<>());
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{n, m}; auto output_shape = std::vector<size_t>{n, m};
...@@ -124,10 +123,8 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -124,10 +123,8 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size(); if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1, auto m = product(input_dims, 0, transpose_axis);
std::multiplies<>()); auto n = product(input_dims, transpose_axis, input_dims.size());
auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1,
std::multiplies<>());
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto input_trans_shape = std::vector<size_t>{n, m}; auto input_trans_shape = std::vector<size_t>{n, m};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment