"docs/vscode:/vscode.git/clone" did not exist on "6184d8a43357f3397c0848b5d0b716cf389d1f30"
Unverified commit a55f4b7f authored by liuduanhui, committed by GitHub

[Enhancement] Replace the implementation of three_nn_forward with mlu-ops (#2719)

parent f946a933
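
For context, three_nn_forward is the three-nearest-neighbors step used in PointNet++-style feature propagation: for each of the n "unknown" query points it finds the three closest of the m "known" points per batch, writing the squared distances to dist2 and the neighbor indices to idx. The sketch below is a minimal CPU reference of that computation for illustration only; it is not part of this commit, the function name is invented, and the (b, n, 3)/(b, m, 3) layouts are taken from the shape checks in the old launcher shown in the diff.

// Hypothetical CPU reference for three_nn_forward (not from the commit).
// unknown: (b, n, 3) query points; known: (b, m, 3) reference points.
// dist2:   (b, n, 3) three smallest squared distances per query point.
// idx:     (b, n, 3) indices into `known` of those three neighbors.
#include <cfloat>

void three_nn_forward_cpu_ref(int b, int n, int m, const float *unknown,
                              const float *known, float *dist2, int *idx) {
  for (int batch = 0; batch < b; ++batch) {
    const float *unk = unknown + batch * n * 3;
    const float *knw = known + batch * m * 3;
    float *d2 = dist2 + batch * n * 3;
    int *id = idx + batch * n * 3;
    for (int i = 0; i < n; ++i) {
      float best[3] = {FLT_MAX, FLT_MAX, FLT_MAX};  // three smallest d^2
      int best_i[3] = {0, 0, 0};
      for (int j = 0; j < m; ++j) {
        float dx = unk[i * 3 + 0] - knw[j * 3 + 0];
        float dy = unk[i * 3 + 1] - knw[j * 3 + 1];
        float dz = unk[i * 3 + 2] - knw[j * 3 + 2];
        float d = dx * dx + dy * dy + dz * dz;
        // Insert d into the running top-3 (smallest first).
        if (d < best[0]) {
          best[2] = best[1]; best_i[2] = best_i[1];
          best[1] = best[0]; best_i[1] = best_i[0];
          best[0] = d; best_i[0] = j;
        } else if (d < best[1]) {
          best[2] = best[1]; best_i[2] = best_i[1];
          best[1] = d; best_i[1] = j;
        } else if (d < best[2]) {
          best[2] = d; best_i[2] = j;
        }
      }
      for (int k = 0; k < 3; ++k) {
        d2[i * 3 + k] = best[k];
        id[i * 3 + k] = best_i[k];
      }
    }
  }
}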
@@ -9,84 +9,47 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
-#include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
-
-void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                          cnrtQueue_t queue, cnrtDataType_t data_type,
-                          const void *unknown, const void *known, void *dist2,
-                          int *idx, const int b, const int n, const int m);
+#include "mlu_common_helper.h"
 
 void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown,
                               const Tensor known, Tensor dist2, Tensor idx) {
-  // Check dtype.
-  TORCH_CHECK(
-      unknown.scalar_type() == at::kFloat || unknown.scalar_type() == at::kHalf,
-      "unknown type should be Float or Half, got ", unknown.scalar_type(), ".");
-  TORCH_CHECK(unknown.scalar_type() == known.scalar_type(),
-              "known should have the same type as unknown.");
-  TORCH_CHECK(unknown.scalar_type() == dist2.scalar_type(),
-              "dist2 should have the same type as unknown.");
-  TORCH_CHECK(idx.scalar_type() == at::kInt, "idx type should be Int.");
-
-  // Check shape.
-  TORCH_CHECK(unknown.dim() == 3, "unknown should be 3d tensor, got ",
-              unknown.dim(), "D.");
-  TORCH_CHECK(known.dim() == 3, "known should be 3d tensor, got ", known.dim(),
-              "D.");
-  TORCH_CHECK(unknown.size(0) == known.size(0),
-              "known.dim0 should be equal to unknown.dim0, got ", known.size(0),
-              ".");
-  TORCH_CHECK(unknown.size(2) == 3, "unknown dim2 should be 3, got ",
-              unknown.size(2), ".");
-  TORCH_CHECK(known.size(2) == 3, "known dim2 should be 3, got ", known.size(2),
-              ".");
-
-  // zero element check
-  TORCH_CHECK(unknown.numel() > 0,
-              "unknown.numel should greater than zero, got ", unknown.numel(),
-              ".");
-  if (known.numel() == 0) {
-    // return if known zero element
-    return;
-  }
-
-  // large tensor check
-  const size_t max_input_num = 2147483648;  // 2^31, 2G num
-  TORCH_CHECK(unknown.numel() < max_input_num,
-              "unknown.numel() should be less than 2147483648, got ",
-              unknown.numel(), ".");
-  TORCH_CHECK(known.numel() < max_input_num,
-              "known.numel() should be less than 2147483648, got ",
-              known.numel(), ".");
-
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
-
-  // get ptr of tensors
-  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown);
+  auto unknown_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      unknown, unknown.suggest_memory_format());
+  auto known_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      known, known.suggest_memory_format());
+  auto dist2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      dist2, dist2.suggest_memory_format());
+  auto idx_contiguous =
+      torch_mlu::cnnl::ops::cnnl_contiguous(idx, idx.suggest_memory_format());
+
+  MluOpTensorDescriptor unknown_desc, known_desc, dist2_desc, idx_desc;
+  unknown_desc.set(unknown_contiguous);
+  known_desc.set(known_contiguous);
+  dist2_desc.set(dist2_contiguous);
+  idx_desc.set(idx_contiguous);
+
+  auto handle = mluOpGetCurrentHandle();
+  size_t workspace_size = 0;
+  mluOpGetThreeNNForwardWorkspaceSize(handle, known_desc.desc(),
+                                      &workspace_size);
+  auto known_workspace =
+      at::empty(workspace_size, known.options().dtype(at::kByte));
+
+  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown_contiguous);
+  auto known_impl = torch_mlu::getMluTensorImpl(known_contiguous);
+  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2_contiguous);
+  auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
+  auto workspace_impl = torch_mlu::getMluTensorImpl(known_workspace);
+
   auto unknown_ptr = unknown_impl->cnnlMalloc();
-  auto known_t = known.permute({0, 2, 1}).contiguous();
-  auto known_impl = torch_mlu::getMluTensorImpl(known_t);
   auto known_ptr = known_impl->cnnlMalloc();
-  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2);
   auto dist2_ptr = dist2_impl->cnnlMalloc();
-  auto idx_impl = torch_mlu::getMluTensorImpl(idx);
   auto idx_ptr = idx_impl->cnnlMalloc();
+  auto workspace_ptr = workspace_impl->cnnlMalloc();
 
-  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
-  cnrtDim3_t k_dim;
-  k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  k_dim.z = 1;
-  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(unknown.dtype());
-
-  // launch kernel
-  CNLOG(INFO) << "Launch Kernel MLUKernelThreeNNForward<<<" << k_dim.x << ", "
-              << k_dim.y << ", " << k_dim.z << ">>>.";
-  KernelThreeNNForward(k_dim, k_type, queue, data_type, unknown_ptr, known_ptr,
-                       dist2_ptr, (int *)idx_ptr, b, n, m);
+  mluOpThreeNNForward(handle, unknown_desc.desc(), unknown_ptr,
+                      known_desc.desc(), known_ptr, workspace_ptr,
+                      workspace_size, dist2_desc.desc(), dist2_ptr,
+                      idx_desc.desc(), idx_ptr);
 }
 
 void three_nn_forward_mlu(int b, int n, int m, const Tensor unknown,
...
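
Design note: the new launcher follows the standard mlu-ops calling convention visible in the diff above. Inputs are first made contiguous with cnnl_contiguous and wrapped in MluOpTensorDescriptor objects, the required scratch size is queried with mluOpGetThreeNNForwardWorkspaceSize and allocated as a byte tensor, and a single mluOpThreeNNForward call replaces the hand-written BANG kernel launch. The dtype, shape, and size checks previously performed in the launcher are thereby delegated to the mlu-ops library.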