/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Replace copies from global to shared memory with async copies
 * \file inject_ptx_async_copy.cc
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "storage_access.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"

namespace tvm {
namespace tl {

using namespace tir;

class PTXAsyncCopyInjector : public StmtMutator {
public:
  Stmt VisitStmt_(const AttrStmtNode *attr) {
    if (attr->attr_key == tir::attr::async_scope) {
      ICHECK(in_async == false) << "Nested async scopes not supported";
      in_async = true;
      auto body = this->VisitStmt(attr->body);
      in_async = false;
      return body;
    }
    return StmtMutator::VisitStmt_(attr);
  }

  Stmt InjectPTX(const BufferLoadNode *load, const BufferStoreNode *store,
                 bool predicated = false,
                 const PrimExpr &predicate_value = PrimExpr()) {
    if (load->buffer.scope() == "global") {
      ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
      ICHECK(load->indices[0]->dtype.lanes() ==
             store->indices[0]->dtype.lanes())
          << load->indices[0] << " vs. " << store->indices[0] << " with lanes "
          << load->indices[0]->dtype.lanes() << " vs. "
          << store->indices[0]->dtype.lanes();
      const int indices_lanes = load->indices[0]->dtype.lanes();
      const int bytes = indices_lanes * load->buffer->dtype.bytes();

      if (bytes == 4 || bytes == 8 || bytes == 16) {
        auto dst_elem_type =
            GetPointerType(store->buffer->data->type_annotation);
        auto src_elem_type =
            GetPointerType(load->buffer->data->type_annotation);
        ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
            << "Both store and load buffer should have a pointer type "
               "annotation.";

        int index_factor = 1;
        if (dst_elem_type.value() != src_elem_type.value()) {
          // The only case where src and dst have different dtypes is when the
          // dst shared memory is a byte buffer generated by merging dynamic
          // shared memory.
          ICHECK(store->buffer.scope() == "shared.dyn" ||
                 store->buffer.scope() == "shared");
          ICHECK(dst_elem_type.value() == DataType::UInt(8));
          // BufferStore/Load have "pointer reinterpret" semantics according
          // to their "value" dtype. Their "indices" are supposed to be
          // applied after such a pointer cast, for example:
          //   ((float16 *)byte_buffer)[buffer->indices] = fp16_value;
          // To replace BufferStore/Load with cp.async, we need to multiply
          // the store index by the byte size of the "value" dtype to get the
          // correct offset into the byte buffer.
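          // A minimal worked example (assumed dtypes): if the source holds
          // float16 (2 bytes) and the original store index is 100, the
          // cp.async destination offset must be 100 * 2 = 200 bytes, so
          // index_factor is set to 2 below.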
          index_factor = src_elem_type->bytes();
        }

        if (indices_lanes == 1) {
          auto src_offset = load->indices[0];
          auto dst_offset = store->indices[0];
          Array<PrimExpr> args = {store->buffer->data,
                                  mul(dst_offset, PrimExpr(index_factor)),
                                  load->buffer->data, src_offset,
                                  PrimExpr(bytes)};
          // Use the argument count to signal whether to emit a predicated
          // cp.async.
          if (predicated) {
            args.push_back(predicate_value);
          }
          return Evaluate(Call(store->buffer->dtype,
                               tvm::tir::builtin::ptx_cp_async(), args));
        }

        // Vectorized indexing. Only some patterns (ramp-based offsets) are
        // supported for now; for predicated copies the predicate is appended
        // as an extra cp.async argument.
        auto src_offset = [=]() -> PrimExpr {
          if (load->indices[0]->IsInstance<RampNode>()) {
            return load->indices[0].as<RampNode>()->base;
          }
          return PrimExpr();
        }();
        auto dst_offset = [=]() -> PrimExpr {
          if (store->indices[0].as<RampNode>()) {
            return store->indices[0].as<RampNode>()->base;
          } else if (store->indices[0].as<AddNode>()) {
            // The case where the dst buffer is a byte buffer generated by
            // merging dynamic shared memory:
            //   A_shared.dyn[ramp(..., 1, 8) + x8(17408)] =
            //       A_global[ramp(..., 1, 8)]
            auto *add = store->indices[0].as<AddNode>();
            if (!add->a->IsInstance<RampNode>())
              return PrimExpr();
            if (!add->b->IsInstance<BroadcastNode>())
              return PrimExpr();
            return tir::Add(add->a.as<RampNode>()->base,
                            add->b.as<BroadcastNode>()->value);
          }
          return PrimExpr();
        }();
        if (src_offset.defined() && dst_offset.defined()) {
          Array<PrimExpr> args = {store->buffer->data,
                                  mul(dst_offset, PrimExpr(index_factor)),
                                  load->buffer->data, src_offset,
                                  PrimExpr(bytes)};
          if (predicated) {
            args.push_back(predicate_value);
          }
          return Evaluate(Call(store->buffer->dtype,
                               tvm::tir::builtin::ptx_cp_async(), args));
        }
      }
    }
    return StmtMutator::VisitStmt_(store);
  }
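  // A minimal before/after sketch of the rewrite performed below
  // (TIR pseudo-code; buffer names are illustrative):
  //   before: A_shared[i] = if_then_else(pred, A_global[j], 0f16)
  //   after:  ptx_cp_async(A_shared.data, i, A_global.data, j, bytes, pred)
  // The unpredicated form emits the same call without the trailing `pred`.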
  Stmt VisitStmt_(const BufferStoreNode *store) {
    bool is_shared = (store->buffer.scope() == "shared" ||
                      store->buffer.scope() == "shared.dyn");
    if (in_async && is_shared) {
      if (auto *load = store->value.as<BufferLoadNode>()) {
        return InjectPTX(load, store);
      } else if (auto *call = store->value.as<CallNode>()) {
        // tir.if_then_else is a call to tir::builtin::if_then_else()
        if (call->op.same_as(builtin::if_then_else()) &&
            call->args.size() == 3) {
          if (auto *load = call->args[1].as<BufferLoadNode>()) {
            // Only an else-value of 0 is supported, since 0 is the default
            // value used by the cp.async PTX instruction.
            // @see section 9.7.8.22.3 of
            // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
            bool else_value_is_zero = false;
            if (auto *b = call->args[2].as<BroadcastNode>()) {
              if (auto *f = b->value.as<FloatImmNode>()) {
                else_value_is_zero = f->value == 0.0f;
              } else if (auto *i = b->value.as<IntImmNode>()) {
                else_value_is_zero = i->value == 0;
              }
            }
            if (auto *f = call->args[2].as<FloatImmNode>()) {
              else_value_is_zero = f->value == 0.0f;
            } else if (auto *i = call->args[2].as<IntImmNode>()) {
              else_value_is_zero = i->value == 0;
            }
            if (else_value_is_zero) {
              return InjectPTX(load, store, true, call->args[0]);
            }
          }
        }
      }
    }
    return StmtMutator::VisitStmt_(store);
  }

private:
  bool in_async{false};
};

using namespace tir::transform;

tvm::transform::Pass InjectPTXAsyncCopy() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    auto *n = f.CopyOnWrite();
    n->body = PTXAsyncCopyInjector()(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
});

} // namespace tl
} // namespace tvm
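// Usage sketch (names are illustrative): a tvm::transform::Pass is callable
// on an IRModule, so from C++ the pass can be applied as
//   IRModule mod = ...;  // module with async_scope-annotated copy loops
//   mod = tvm::tl::InjectPTXAsyncCopy()(mod);
// From Python it is reachable via the registered global
// "tl.transform.InjectPTXAsyncCopy".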