".github/vscode:/vscode.git/clone" did not exist on "7fa3e5b0f6a593d06e65f5b40ccd46acfafcfeb1"
Commit dbe08e9b authored by yuguo960516yuguo

2.4.2

parent b5499578
@@ -107,15 +107,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
sizeof(bias_data)));
if (enable_auxiliary && activation != "none") {
size_t reserve_space_size = 0;
// Note (Ming Huang): The initialization of ReserveSpace happens in
// dev_ctx.Alloc. Therefore, we set the real data type up here.
if (activation == "relu") {
// Count in bits.
reserve_space_size = phi::product(out->dims()) / 8;
paddle::experimental::DataType rs_type =
paddle::experimental::DataType::BOOL;
size_t reserve_space_size =
phi::product(reserve_space->dims()) * SizeOf(rs_type);
dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
} else {
reserve_space_size = phi::product(out->dims()) * sizeof(T);
size_t reserve_space_size =
phi::product(reserve_space->dims()) * sizeof(T);
dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
}
dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());
void* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
@@ -185,7 +191,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
stream,
workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
operation_desc,
@@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(epiloque_func_for_dx)));
if (activation_grad != "none") {
auto* aux_data = reserve_space->data<T>();
auto* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
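Note: the first hunk above (around line 107) switches the ReLU reserve-space sizing from a bit-packed count over out->dims() (one bit per element, hence the division by 8) to a BOOL tensor sized from reserve_space->dims(). A standalone sketch of the two formulas, with a made-up shape and assuming reserve_space holds one entry per output element (illustration only, not part of the patch):

// Contrast of the old and new reserve-space sizing, using a made-up shape.
// phi::product(dims) is just the element count of the tensor.
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical GEMM output shape: [batch=8, rows=128, cols=512].
  const std::size_t num_elems = 8ull * 128 * 512;

  // Old scheme: the ReLU mask is counted in bits, so the byte size is the
  // element count divided by 8 (one bit per output element).
  const std::size_t old_relu_bytes = num_elems / 8;

  // New scheme: one BOOL per element, so the byte size is the element count
  // times the size of the BOOL data type (one byte).
  const std::size_t new_relu_bytes = num_elems * sizeof(bool);

  std::printf("old (bit mask):    %zu bytes\n", old_relu_bytes);    // 65536
  std::printf("new (bool tensor): %zu bytes\n", new_relu_bytes);    // 524288
  return 0;
}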
@@ -39,14 +39,23 @@ cc_test(
SRCS test.cc
DEPS jit_kernel_helper)
if(NOT WIN32)
cc_binary(
jit_kernel_benchmark
SRCS
benchmark.cc
DEPS
jit_kernel_helper
device_tracer
tensor)
set(cuda_less12_and_gcc_greater12 false)
if(DEFINED CMAKE_CUDA_COMPILER_VERSION)
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0
AND ${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 12.0)
set(cuda_less12_and_gcc_greater12 true)
endif()
endif()
if(NOT cuda_less12_and_gcc_greater12)
cc_binary(
jit_kernel_benchmark
SRCS
benchmark.cc
DEPS
jit_kernel_helper
device_tracer
tensor)
endif()
endif()
if(WITH_TESTING AND TEST jit_kernel_test)
set_tests_properties(jit_kernel_test PROPERTIES TIMEOUT 120)
@@ -214,10 +214,7 @@ class MatMulMKLDNNHandler
}
astream.wait();
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(
@@ -651,10 +648,18 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
// TODO(jczaja): Explain why the int8 format of dst is ABCD and does not need
// a permute
if (IsOutputFused(ctx) && !IsInt8<T_out>()) {
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(
permuted_md.reshape(phi::vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
}
}
template <typename T>
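Note: the ExecuteMatMulV2 hunk above replaces the fixed nchw format/layout pair with a memory descriptor taken from the destination memory, optionally permuted by fused_transpose_Out and reshaped to the tensor's dims. A minimal standalone oneDNN 2.x sketch of those two descriptor operations, with made-up dims and a made-up axis order (not the kernel's actual shapes):

// Standalone oneDNN 2.x sketch of the two descriptor operations used above
// (made-up dims and axis order; the kernel chains them as permute, then reshape).
#include <vector>
#include "dnnl.hpp"

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  // A plain, row-major 3D destination descriptor, e.g. [batch=2, rows=3, cols=4].
  dnnl::memory::desc dst_md({2, 3, 4}, dt::f32, tag::abc);

  // permute_axes: relabel the logical axes (no data movement), as done for the
  // fused_transpose_Out attribute.
  std::vector<int> axis = {0, 2, 1};
  dnnl::memory::desc permuted_md = dst_md.permute_axes(axis);

  // reshape: expose the same buffer under the dims the framework tensor
  // reports, here splitting a flattened 2D view back out: [6, 4] -> [2, 3, 4].
  dnnl::memory::desc flat_md({6, 4}, dt::f32, tag::ab);
  dnnl::memory::desc unflattened_md = flat_md.reshape({2, 3, 4});

  (void)permuted_md;
  (void)unflattened_md;
  return 0;
}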
@@ -836,8 +841,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
reduction_p->execute(astream, reduction_args);
astream.wait();
dx->set_format(paddle::platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(squeezed_dims)));
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
@@ -1119,9 +1123,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
template <typename T>
@@ -1184,13 +1187,13 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext &ctx) const {
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
dx->set_format(x.format());
dx->set_mem_desc(x.mem_desc());
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
dy->set_format(y.format());
dy->set_mem_desc(y.mem_desc());
}
}
}
@@ -221,7 +221,7 @@ class MulPrimitiveFactory {
to_void_cast<T>(x_tmp.data<T>()));
x_tmp.Resize(data->dims());
x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc));
x_tmp.set_mem_desc(dst_mdesc);
data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
} else {
data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
@@ -235,11 +235,7 @@ class MulPrimitiveFactory {
const Tensor *in) {
x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
if (out->format() == MKLDNNMemoryFormat::undef) {
auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format);
}
out->set_mem_desc(output_->get_desc());
}
template <typename T>
@@ -272,7 +268,7 @@ class MulPrimitiveFactory {
auto buffer_size = dst_desc.get_size();
OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc));
output->set_mem_desc(dst_desc);
return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
}
@@ -392,9 +388,10 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
if (out_dims.size() != 2) {
out->Resize(out_dims);
}
out->set_layout(DataLayout::kMKLDNN);
out->set_format(platform::MKLDNNFormatForSize(out_dims.size(),
MKLDNNMemoryFormat::nchw));
auto in_md = dnnl::memory::desc(*dnnl_primitive_desc_query_md(
mul.get_primitive_desc(), dnnl_query_dst_md, 0));
out->set_mem_desc(in_md.reshape(phi::vectorize<int64_t>(out->dims())));
}
};
@@ -442,10 +439,11 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
// plain output formats are enforced inside handler
out->set_format(platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw));
// This kernel flattens the dims, so we need to set the un-flattened version
// on out. The reshape requires a plain layout, but MatmulV2MKLDNNHandler
// enforces one, so it should work.
out->set_mem_desc(
dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
}
private:
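Note: the two Mul kernel hunks above follow the same pattern: run the matmul on flattened 2D shapes, then attach an un-flattened memory descriptor derived from the primitive's destination descriptor instead of setting a fixed format/layout. A standalone sketch of that pattern using the standard oneDNN 2.x API, with made-up shapes (not the kernel's actual ones):

// Standalone oneDNN 2.x sketch: flattened 2D matmul, then un-flatten the
// destination descriptor before handing it to the output tensor.
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  // Flattened 2D matmul: [M=6, K=8] x [K=8, N=4] -> [M=6, N=4].
  dnnl::memory::desc src_md({6, 8}, dt::f32, tag::ab);
  dnnl::memory::desc wei_md({8, 4}, dt::f32, tag::ab);
  dnnl::memory::desc dst_md({6, 4}, dt::f32, tag::ab);

  dnnl::matmul::desc mm_d(src_md, wei_md, dst_md);
  dnnl::matmul::primitive_desc mm_pd(mm_d, eng);

  // The destination descriptor actually chosen by the primitive ...
  dnnl::memory::desc chosen_dst = mm_pd.dst_desc();

  // ... reshaped back to the un-flattened dims the output tensor reports,
  // e.g. an original [2, 3, 4] that was flattened to [6, 4] for the matmul.
  dnnl::memory::desc unflattened = chosen_dst.reshape({2, 3, 4});
  (void)unflattened;
  return 0;
}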