Commit dbe08e9b authored by yuguo960516yuguo

2.4.2

parent b5499578
@@ -107,15 +107,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
             sizeof(bias_data)));
     if (enable_auxiliary && activation != "none") {
-      size_t reserve_space_size = 0;
+      // Note (Ming Huang): The initialization of ReseveSpace is happened in the
+      // dev_ctx.Alloc. Therefore, we set real date type up here.
       if (activation == "relu") {
-        // Count in bits.
-        reserve_space_size = phi::product(out->dims()) / 8;
+        paddle::experimental::DataType rs_type =
+            paddle::experimental::DataType::BOOL;
+        size_t reserve_space_size =
+            phi::product(reserve_space->dims()) * SizeOf(rs_type);
+        dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
       } else {
-        reserve_space_size = phi::product(out->dims()) * sizeof(T);
+        size_t reserve_space_size =
+            phi::product(reserve_space->dims()) * sizeof(T);
+        dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
       }
-      dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
-      void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());
+      void* aux_data = reserve_space->data();
       PADDLE_ENFORCE_GPU_SUCCESS(
           platform::dynload::cublasLtMatmulDescSetAttribute(
@@ -185,7 +191,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
             stream,
             workspace->ptr(),
             workspace_size);
-
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::cublasLtMatmul(lt_handle,
                                           operation_desc,
@@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
               sizeof(epiloque_func_for_dx)));
       if (activation_grad != "none") {
-        auto* aux_data = reserve_space->data<T>();
+        auto* aux_data = reserve_space->data();
         PADDLE_ENFORCE_GPU_SUCCESS(
             platform::dynload::cublasLtMatmulDescSetAttribute(
                 dx_operation_desc,
......
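For context, the reserve space touched by both hunks is the auxiliary buffer written by the cuBLASLt epilogue in the forward pass (a ReLU mask when activation is relu) and read back by the gradient kernel, which is why it is now allocated with an explicit data type and handed to the descriptor as the untyped reserve_space->data() pointer. Below is a minimal sketch of how such an auxiliary pointer is attached to a matmul descriptor; the helper name and leading dimension are illustrative and error handling is omitted, only the cublasLt attribute enums are the library's own.

// Hypothetical sketch: attach a ReLU auxiliary buffer to a cuBLASLt matmul
// descriptor, mirroring the attributes the kernel sets through
// platform::dynload. aux_data / aux_ld are assumed to come from a reserve
// space sized for the chosen epilogue.
#include <cublasLt.h>

void SetReluAuxOnDesc(cublasLtMatmulDesc_t desc, void* aux_data,
                      int64_t aux_ld) {
  // Epilogue that applies bias + ReLU and stores the ReLU mask as aux output.
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_RELU_AUX_BIAS;
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  // Untyped pointer to the auxiliary (reserve space) buffer.
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &aux_data,
      sizeof(aux_data));
  // Leading dimension of the auxiliary buffer.
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld, sizeof(aux_ld));
}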
@@ -39,14 +39,23 @@ cc_test(
   SRCS test.cc
   DEPS jit_kernel_helper)
 if(NOT WIN32)
-  cc_binary(
-    jit_kernel_benchmark
-    SRCS
-    benchmark.cc
-    DEPS
-    jit_kernel_helper
-    device_tracer
-    tensor)
+  set(cuda_less12_and_gcc_greater12 false)
+  if(DEFINED CMAKE_CUDA_COMPILER_VERSION)
+    if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0
+       AND ${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 12.0)
+      set(cuda_less12_and_gcc_greater12 true)
+    endif()
+  endif()
+  if(NOT cuda_less12_and_gcc_greater12)
+    cc_binary(
+      jit_kernel_benchmark
+      SRCS
+      benchmark.cc
+      DEPS
+      jit_kernel_helper
+      device_tracer
+      tensor)
+  endif()
 endif()
 if(WITH_TESTING AND TEST jit_kernel_test)
   set_tests_properties(jit_kernel_test PROPERTIES TIMEOUT 120)
......
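The guard introduced here derives cuda_less12_and_gcc_greater12 from CMAKE_CUDA_COMPILER_VERSION and CMAKE_CXX_COMPILER_VERSION and builds the jit_kernel_benchmark binary only while that flag stays false, i.e. the benchmark is skipped when CUDA is older than 12.0 but the host C++ compiler is newer than 12.0, presumably because that old-CUDA/new-GCC combination cannot compile the benchmark, so it is left out rather than breaking the build.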
@@ -214,10 +214,7 @@ class MatMulMKLDNNHandler
     }
     astream.wait();
-    auto format =
-        MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
-    out->set_format(format);
-    out->set_layout(DataLayout::kMKLDNN);
+    out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
   }
   std::shared_ptr<dnnl::memory> AcquireDstMemory(
@@ -651,10 +648,18 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
   auto &astream = MKLDNNDeviceContext::tls().get_stream();
   matmul_p->execute(astream, matmul_args);
   astream.wait();
-  auto format =
-      MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
-  out->set_format(format);
-  out->set_layout(DataLayout::kMKLDNN);
+
+  // TODO(jczaja): Explain why int8 format of dst is ABCD and do not need
+  // permute
+  if (IsOutputFused(ctx) && !IsInt8<T_out>()) {
+    auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
+    auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
+    out->set_mem_desc(
+        permuted_md.reshape(phi::vectorize<int64_t>(out->dims())));
+  } else {
+    out->set_mem_desc(
+        dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
+  }
 }
 
 template <typename T>
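The fused-transpose branch above combines two oneDNN memory-descriptor operations: permute_axes reorders the logical dimensions according to the fused_transpose_Out axis list, and reshape then restates the descriptor with the operator's output dims. A self-contained sketch of that pattern, with made-up shapes and a made-up permutation unrelated to the Paddle handler:

// Standalone illustration of dnnl::memory::desc::permute_axes + reshape,
// mirroring the pattern applied to dst_memory_p->get_desc() above.
// Shapes and the permutation are invented for the example.
#include <dnnl.hpp>
#include <vector>

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  // Descriptor of a dense 4-D matmul destination.
  dnnl::memory::desc dst_md({2, 4, 8, 16}, dt::f32, tag::abcd);

  // Permute logical axes, as done for the "fused_transpose_Out" attribute;
  // axis i of dst_md becomes axis axis[i] of the result.
  std::vector<int> axis = {0, 2, 1, 3};
  dnnl::memory::desc permuted_md = dst_md.permute_axes(axis);

  // Restate the permuted descriptor with the tensor's final dims.
  dnnl::memory::desc out_md = permuted_md.reshape({2, 8, 4, 16});

  // The underlying buffer size is unchanged by either operation.
  return out_md.get_size() == dst_md.get_size() ? 0 : 1;
}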
@@ -836,8 +841,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
     reduction_p->execute(astream, reduction_args);
     astream.wait();
-    dx->set_format(paddle::platform::GetMKLDNNFormat(
-        dst_memory_p->get_desc().reshape(squeezed_dims)));
+    dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
   }
 
   std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
@@ -1119,9 +1123,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
   matmul_p->execute(astream, matmul_args);
   astream.wait();
-  out->set_layout(framework::DataLayout::kMKLDNN);
-  out->set_format(platform::GetMKLDNNFormat(
-      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
+  out->set_mem_desc(
+      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
 }
 
 template <typename T>
@@ -1184,13 +1187,13 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext &ctx) const {
   if (dx) {
     if (dx_dims != x.dims()) {
       dx->Resize(dx_dims);
-      dx->set_format(x.format());
+      dx->set_mem_desc(x.mem_desc());
     }
   }
   if (dy) {
     if (dy_dims != y.dims()) {
       dy->Resize(dy_dims);
-      dy->set_format(y.format());
+      dy->set_mem_desc(y.mem_desc());
     }
   }
 }
......
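All of the matmul changes above follow one pattern: instead of marking the output with DataLayout::kMKLDNN plus a coarse nchw-style format tag, the kernels now attach the full oneDNN memory descriptor to the tensor through set_mem_desc, usually the destination primitive's descriptor reshaped to the tensor's logical dims, and gradient tensors that are merely resized inherit the descriptor of the matching input (dx->set_mem_desc(x.mem_desc())). A memory descriptor carries dimensions, strides and padding, which a bare format tag cannot express.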
@@ -221,7 +221,7 @@ class MulPrimitiveFactory {
                             to_void_cast<T>(x_tmp.data<T>()));
       x_tmp.Resize(data->dims());
-      x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc));
+      x_tmp.set_mem_desc(dst_mdesc);
       data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
     } else {
       data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
@@ -235,11 +235,7 @@ class MulPrimitiveFactory {
                           const Tensor *in) {
     x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
     output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
-
-    if (out->format() == MKLDNNMemoryFormat::undef) {
-      auto output_format = platform::GetMKLDNNFormat(*output_);
-      out->set_format((MKLDNNMemoryFormat)output_format);
-    }
+    out->set_mem_desc(output_->get_desc());
   }
 
   template <typename T>
@@ -272,7 +268,7 @@ class MulPrimitiveFactory {
     auto buffer_size = dst_desc.get_size();
     OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
-    output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc));
+    output->set_mem_desc(dst_desc);
     return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
   }
@@ -392,9 +388,10 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
     if (out_dims.size() != 2) {
       out->Resize(out_dims);
     }
-    out->set_layout(DataLayout::kMKLDNN);
-    out->set_format(platform::MKLDNNFormatForSize(out_dims.size(),
-                                                  MKLDNNMemoryFormat::nchw));
+
+    auto in_md = dnnl::memory::desc(*dnnl_primitive_desc_query_md(
+        mul.get_primitive_desc(), dnnl_query_dst_md, 0));
+    out->set_mem_desc(in_md.reshape(phi::vectorize<int64_t>(out->dims())));
   }
 };
@@ -442,10 +439,11 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
     matmul_p->execute(astream, matmul_args);
     astream.wait();
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    // plain output formats are enforced inside handler
-    out->set_format(platform::MKLDNNFormatForSize(
-        out->dims().size(), dnnl::memory::format_tag::nchw));
+
+    // This kernel is flattening dims so then we need to unflattened version
+    // that should be set in out reshape require plain layout, but
+    // MatmulV2MKLDNNHanlder enforces one so it should work
+    out->set_mem_desc(
+        dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
   }
 
  private:
......
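As the in-diff comments note, the mul kernels compute on flattened 2-D matrices, so the destination descriptor oneDNN produces is 2-D; the final reshape(phi::vectorize<int64_t>(out->dims())) restates it with the operator's original N-D output shape, and the int8 path first queries that descriptor back from the primitive descriptor via dnnl_query_dst_md. For a dense descriptor such a split of dimensions is legal; an illustrative snippet with invented shapes:

// Illustrative only: a dense 2-D result {6, 16} restated as a 4-D output
// {2, 3, 4, 4} via dnnl::memory::desc::reshape, the same call used above.
#include <dnnl.hpp>

int main() {
  dnnl::memory::desc flat({6, 16}, dnnl::memory::data_type::f32,
                          dnnl::memory::format_tag::ab);
  dnnl::memory::desc unflat = flat.reshape({2, 3, 4, 4});
  return unflat.get_size() == flat.get_size() ? 0 : 1;
}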