Commit 0211193c authored by zhuwenwen

initial llama
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
import sys
if __name__ == '__main__':

    print('grouping duplicate urls ...')

    input = sys.argv[1]
    output = sys.argv[2]
    if len(sys.argv) > 3:
        jaccard_similarity_threshold = float(sys.argv[3])
    else:
        jaccard_similarity_threshold = 0.7

    url_to_index = {}
    index_to_urls = []
    counter = 0
    start_time = time.time()
    with open(input, 'r') as f:
        for line in f:
            counter += 1
            myjson = json.loads(line)
            urls = []
            for main_url in myjson.keys():
                urls.append(main_url)
                for value in myjson[main_url]:
                    for other_url, js in value.items():
                        if js >= jaccard_similarity_threshold:
                            urls.append(other_url)
            current_index = -1
            other_indices = set()
            for url in urls:
                if url in url_to_index:
                    if current_index == -1:
                        current_index = url_to_index[url]
                    elif current_index != url_to_index[url]:
                        other_indices.add(url_to_index[url])
            if current_index == -1:
                current_index = len(index_to_urls)
                index_to_urls.append(set())
            for url in urls:
                url_to_index[url] = current_index
                index_to_urls[current_index].add(url)
            # Merge any previously separate groups that this line connects.
            for index in other_indices:
                for url in index_to_urls[index]:
                    index_to_urls[current_index].add(url)
                    url_to_index[url] = current_index
                index_to_urls[index] = None
            if counter % 100000 == 0:
                print(' > processed {} lines in {} seconds ...'.format(
                    counter, time.time() - start_time))

    total_remove = 0
    total_remain = 0
    for urls in index_to_urls:
        if urls is not None:
            if len(urls) > 1:
                total_remove += (len(urls) - 1)
                total_remain += 1
    print('out of {} urls, only {} are unique and {} should be removed'.format(
        total_remove + total_remain, total_remain, total_remove))

    with open(output, 'wb') as f:
        for i, urls in enumerate(index_to_urls):
            if urls is not None:
                if len(urls) > 1:
                    myjson = json.dumps({str(i): list(urls)},
                                        ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
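# Usage sketch (file names below are hypothetical): each input line is a JSON
# object mapping a main URL to lists of {other_url: jaccard_similarity} pairs,
# e.g.
#   {"http://a.com/1": [{"http://b.com/1": 0.92}, {"http://c.com/1": 0.35}]}
# Running
#   python group_duplicate_urls.py similarities.jsonl url_groups.jsonl 0.7
# writes one JSON object per duplicate group: {"<group id>": [url, url, ...]}.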
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import sys
import json
import argparse
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--json_path", type=str, default=".",
                        help="path where all the json files are located")
    parser.add_argument("--output_file", type=str, default="merged_output.json",
                        help="filename where the merged json should go")

    args = parser.parse_args()

    json_path = args.json_path
    out_file = args.output_file

    json_files = glob.glob(json_path + '/*.json')
    counter = 0

    with open(out_file, 'w') as outfile:
        for fname in json_files:
            counter += 1
            if counter % 1024 == 0:
                print("Merging at ", counter, flush=True)
            with open(fname, 'r') as infile:
                for row in infile:
                    # Parse each row first so malformed JSON fails loudly
                    # instead of being copied through silently.
                    each_row = json.loads(row)
                    outfile.write(row)

    print("Merged file", out_file, flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
import sys
if __name__ == '__main__':

    url_filename = sys.argv[1]
    data_filename = sys.argv[2]
    output_filename = sys.argv[3]

    urls = set()
    with open(url_filename, 'r') as f:
        for line in f:
            myjson = json.loads(line)
            for key in myjson:
                this_urls = myjson[key]
                # Keep the first URL of each duplicate group and mark the
                # rest for removal.
                for i in range(1, len(this_urls)):
                    urls.add(this_urls[i])
    print('will be removing {} urls'.format(len(urls)), flush=True)

    written_docs = 0
    removed_docs = 0
    removed_chars = 0
    start_time = time.time()
    with open(output_filename, 'wb') as fout:
        with open(data_filename, 'r') as fin:
            for line in fin:
                try:
                    myjson = json.loads(line)
                    url = myjson['url']
                    if url in urls:
                        print('removing', myjson)
                        removed_docs += 1
                        removed_chars += len(myjson['text'])
                        continue
                    myjson = json.dumps(myjson, ensure_ascii=False)
                    fout.write(myjson.encode('utf-8'))
                    fout.write('\n'.encode('utf-8'))
                    written_docs += 1
                    if written_docs % 10000 == 0:
                        print(' [PROCESSED] time (s): {:.2f} | written: {} '
                              '| removed: {} (char: {})'.format(
                                  time.time() - start_time,
                                  written_docs, removed_docs, removed_chars))
                except Exception as e:
                    print('[SKIPPING]', line, e)

    print(' [PROCESSED] time (s): {:.2f} | written: {} '
          '| removed: {} (char: {})'.format(
              time.time() - start_time,
              written_docs, removed_docs, removed_chars))
    print('done :-)')
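# Usage sketch (file names are hypothetical): apply the groups produced by
# group_duplicate_urls.py to a corpus of {"url": ..., "text": ...} records,
#   python remove_group_duplicates.py url_groups.jsonl corpus.jsonl deduped.jsonl
# keeping one document per group and dropping the rest.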
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processing data for pretraining."""
import argparse
import json
import multiprocessing
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import torch
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):

    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                          # <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)        # <-- Normally you would have \s+ here
        ))"""


class IdentitySplitter(object):
    def tokenize(self, *text):
        return text


class Encoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)
        if self.args.split_sentences:
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            splitter = nltk.load("tokenizers/punkt/english.pickle")
            if self.args.keep_newlines:
                # this prevents punkt from eating newlines after sentences
                Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                    train_text=splitter._params,
                    lang_vars=CustomLanguageVars())
            else:
                Encoder.splitter = splitter
        else:
            Encoder.splitter = IdentitySplitter()

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
        for key in self.args.json_keys:
            text = data[key]
            doc_ids = []
            for sentence in Encoder.splitter.tokenize(text):
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids[-1].append(Encoder.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)


def get_args():
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space separated list of keys to extract from json')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase', 'BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')
    group.add_argument('--dataset-impl', type=str, default='mmap',
                       choices=['lazy', 'cached', 'mmap'])

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, default=1,
                       help='Number of worker processes to launch')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Interval between progress updates')
    args = parser.parse_args()
    args.keep_empty = False

    if args.tokenizer_type.lower().startswith('bert'):
        if not args.split_sentences:
            print("Bert tokenizer detected, are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    args.vocab_extra_ids = 0

    return args


def main():
    args = get_args()
    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    if nltk_available and args.split_sentences:
        nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    tokenizer = build_tokenizer(args)
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)
    #encoded_docs = map(encoder.encode, fin)

    level = "document"
    if args.split_sentences:
        level = "sentence"

    print(f"Vocab size: {tokenizer.vocab_size}")
    print(f"Output prefix: {args.output_prefix}")
    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                      key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
                                                      key, level)
        builders[key] = indexed_dataset.make_builder(output_bin_files[key],
                                                     impl=args.dataset_impl,
                                                     vocab_size=tokenizer.vocab_size)

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)

    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for key, sentences in doc.items():
            if len(sentences) == 0:
                continue
            for sentence in sentences:
                builders[key].add_item(torch.IntTensor(sentence))
            builders[key].end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed / elapsed / 1024 / 1024
            print(f"Processed {i} documents",
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    for key in args.json_keys:
        builders[key].finalize(output_idx_files[key])


if __name__ == '__main__':
    main()
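# Usage sketch (paths are hypothetical); the flags are the ones defined in
# get_args() above, e.g. for a GPT-style corpus:
#   python preprocess_data.py \
#       --input corpus.jsonl \
#       --json-keys text \
#       --tokenizer-type GPT2BPETokenizer \
#       --vocab-file gpt2-vocab.json \
#       --merge-file gpt2-merges.txt \
#       --append-eod \
#       --output-prefix my_corpus \
#       --workers 8
# Given the naming scheme in main(), this writes my_corpus_text_document.bin
# and my_corpus_text_document.idx for consumption by megatron.data.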
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import socket
from megatron import get_args
from megatron import print_rank_0
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation import generate_and_post_process
import torch
def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process)
    return model


def add_text_generate_args(parser):
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    return parser


if __name__ == "__main__":
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # Rank 0 of the first pipeline stage serves HTTP requests; all other
    # ranks wait in the broadcast loop below for generation work.
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        server = MegatronServer(model)
        server.run("0.0.0.0")

    while True:
        choice = torch.cuda.LongTensor(1)
        torch.distributed.broadcast(choice, 0)
        if choice[0].item() == 0:
            generate_and_post_process(model)
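# Launch sketch (checkpoint path, vocab files, and parallel sizes are
# hypothetical); Megatron scripts run under a distributed launcher, e.g.:
#   python -m torch.distributed.launch --nproc_per_node 1 \
#       run_text_generation_server.py \
#       --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 \
#       --load ./gpt_checkpoint \
#       --vocab-file gpt2-vocab.json --merge-file gpt2-merges.txt \
#       --temperature 1.0 --top_p 0.9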
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import urllib2
# Note: this client targets Python 2 (urllib2, raw_input).
class PutRequest(urllib2.Request):
    '''Request subclass that issues PUT instead of GET with urllib2.'''

    def get_method(self, *args, **kwargs):
        return 'PUT'


if __name__ == "__main__":
    url = sys.argv[1]
    while True:
        sentence = raw_input("Enter prompt: ")
        tokens_to_generate = int(input("Enter number of tokens to generate: "))
        data = json.dumps({"prompts": [sentence],
                           "tokens_to_generate": tokens_to_generate})
        req = PutRequest(url, data, {'Content-Type': 'application/json'})
        response = urllib2.urlopen(req)
        resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences["text"][0])
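# Usage sketch (the address and endpoint are hypothetical): point the client
# at the server started by run_text_generation_server.py, e.g.
#   python2 text_generation_cli.py http://localhost:5000/api
# It PUTs {"prompts": [...], "tokens_to_generate": N} as JSON and prints the
# first entry of the "text" field from the response.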
#
# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cmake_minimum_required(VERSION 3.8)
set(cuda_driver_wrapper_files
cudaDriverWrapper.cpp
)
add_library(cuda_driver_wrapper STATIC ${cuda_driver_wrapper_files})
target_link_libraries(cuda_driver_wrapper PRIVATE -lcublas -lcudart)
set_property(TARGET cuda_driver_wrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cuda_driver_wrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
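# Usage sketch (the parent project and target names are hypothetical): a
# consuming CMakeLists.txt can link against this static library with, e.g.:
#   add_subdirectory(cuda_driver_wrapper)
#   target_link_libraries(my_kernel_runner PRIVATE cuda_driver_wrapper)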
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif // defined(WIN32_LEAN_AND_MEAN)
#include <windows.h>
#define dllOpen(name) (void*) LoadLibraryA("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) GetProcAddress(static_cast<HMODULE>(handle), name)
#else
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif
#include "cudaDriverWrapper.h"
// #include "plugin.h"
#include <cuda.h>
#include <stdio.h>
// using namespace nvinfer1;
CUDADriverWrapper::CUDADriverWrapper()
{
    handle = dllOpen(CUDA_LIB_NAME);
    // ASSERT(handle != nullptr); // TODO check

    auto load_sym = [](void* handle, const char* name) {
        void* ret = dllGetSym(handle, name);
        // ASSERT(ret != nullptr); // TODO check
        return ret;
    };

    *(void**) (&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
    *(void**) (&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
    *(void**) (&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
    *(void**) (&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
    *(void**) (&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
    *(void**) (&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
    *(void**) (&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
    *(void**) (&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
    *(void**) (&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
    *(void**) (&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
    *(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
    *(void**) (&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
}

CUDADriverWrapper::~CUDADriverWrapper()
{
    dllClose(handle);
}

CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, const char** pStr) const
{
    return (*_cuGetErrorName)(error, pStr);
}

CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
    return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}

CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
    return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}

CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
    return (*_cuModuleUnload)(hmod);
}

CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
    return (*_cuLinkDestroy)(state);
}

CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, const void* image) const
{
    return (*_cuModuleLoadData)(module, image);
}

CUresult CUDADriverWrapper::cuLinkCreate(
    unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
    return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}

CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const
{
    return (*_cuModuleGetFunction)(hfunc, hmod, name);
}

CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path,
    unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
    return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}

CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
    const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
    return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}

CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
    return (*_cuLaunchCooperativeKernel)(
        f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}

CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
    return (*_cuLaunchKernel)(
        f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
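// Usage sketch (module image, kernel name, and launch sizes are hypothetical):
// the wrapper resolves driver symbols at runtime, so there is no link-time
// dependency on libcuda. With the cuErrCheck macro from cudaDriverWrapper.h:
//
//     CUDADriverWrapper wrap;
//     CUmodule module;
//     CUfunction kernel;
//     cuErrCheck(wrap.cuModuleLoadData(&module, cubin_image), wrap);
//     cuErrCheck(wrap.cuModuleGetFunction(&kernel, module, "my_kernel"), wrap);
//     void* params[] = {&arg0, &arg1};
//     cuErrCheck(wrap.cuLaunchKernel(kernel, grid, 1, 1, block, 1, 1,
//                                    0, stream, params, nullptr), wrap);
//     cuErrCheck(wrap.cuModuleUnload(module), wrap);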
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include <cstdio>
#include <cuda.h>
#pragma once
#define cuErrCheck(stat, wrap)                                 \
    {                                                          \
        cuErrCheck_((stat), wrap, __FILE__, __LINE__);         \
    }

// namespace nvinfer1
// {
class CUDADriverWrapper
{
public:
    CUDADriverWrapper();
    ~CUDADriverWrapper();

    CUresult cuGetErrorName(CUresult error, const char** pStr) const;
    CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
    CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
    CUresult cuModuleUnload(CUmodule hmod) const;
    CUresult cuLinkDestroy(CUlinkState state) const;
    CUresult cuModuleLoadData(CUmodule* module, const void* image) const;
    CUresult cuLinkCreate(
        unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
    CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const;
    CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path,
        unsigned int numOptions, CUjit_option* options, void** optionValues) const;
    CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
        const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const;
    CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
        unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
        unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
    CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
        CUstream hStream, void** kernelParams, void** extra) const;

private:
    void* handle;
    CUresult (*_cuGetErrorName)(CUresult, const char**);
    CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
    CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
    CUresult (*_cuModuleUnload)(CUmodule);
    CUresult (*_cuLinkDestroy)(CUlinkState);
    CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
    CUresult (*_cuModuleLoadData)(CUmodule*, const void*);
    CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, const char*);
    CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, const char*, unsigned int, CUjit_option*, void**);
    CUresult (*_cuLinkAddData)(
        CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
    CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
        unsigned int, unsigned int, unsigned int, CUstream, void**);
    CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
        CUstream hStream, void** kernelParams, void** extra);
};

inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line)
{
    if (stat != CUDA_SUCCESS)
    {
        const char* msg = nullptr;
        wrap.cuGetErrorName(stat, &msg);
        fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line);
    }
}

// } // namespace nvinfer1
#endif // CUDA_DRIVER_WRAPPER_H
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_
#include "../config.hpp"
#include <cub/rocprim/block/block_adjacent_difference.hpp>
BEGIN_HIPCUB_NAMESPACE
namespace detail
{
// Trait checks if FlagOp can be called with 3 arguments (a, b, b_index)
template<class T, class FlagOp, class = void>
struct WithBIndexArg
: std::false_type
{ };
template<class T, class FlagOp>
struct WithBIndexArg<
T, FlagOp,
typename std::conditional<
true,
void,
decltype(std::declval<FlagOp>()(std::declval<T>(), std::declval<T>(), 0))
>::type
> : std::true_type
{ };
}
template<
typename T,
int BLOCK_DIM_X,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH /* ignored */
>
class BlockAdjacentDifference
: private ::rocprim::block_adjacent_difference<
T,
BLOCK_DIM_X,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>
{
static_assert(
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
"BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
);
using base_type =
typename ::rocprim::block_adjacent_difference<
T,
BLOCK_DIM_X,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>;
// Reference to temporary storage (usually shared memory)
typename base_type::storage_type& temp_storage_;
public:
using TempStorage = typename base_type::storage_type;
HIPCUB_DEVICE inline
BlockAdjacentDifference() : temp_storage_(private_storage())
{
}
HIPCUB_DEVICE inline
BlockAdjacentDifference(TempStorage& temp_storage) : temp_storage_(temp_storage)
{
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_heads(head_flags, input, flag_op, temp_storage_);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op,
T tile_predecessor_item)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_heads(head_flags, tile_predecessor_item, input, flag_op, temp_storage_);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_tails(tail_flags, input, flag_op, temp_storage_);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op,
T tile_successor_item)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_tails(tail_flags, tile_successor_item, input, flag_op, temp_storage_);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_heads_and_tails(
head_flags, tail_flags, input,
flag_op, temp_storage_
);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T tile_successor_item,
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_heads_and_tails(
head_flags, tail_flags, tile_successor_item, input,
flag_op, temp_storage_
);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
T tile_predecessor_item,
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_heads_and_tails(
head_flags, tile_predecessor_item, tail_flags, input,
flag_op, temp_storage_
);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
[[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]]
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
T tile_predecessor_item,
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T tile_successor_item,
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated")
base_type::flag_heads_and_tails(
head_flags, tile_predecessor_item, tail_flags, tile_successor_item, input,
flag_op, temp_storage_
);
HIPCUB_CLANG_SUPPRESS_WARNING_POP
}
template <int ITEMS_PER_THREAD, typename OutputType, typename DifferenceOpT>
HIPCUB_DEVICE inline
void SubtractLeft(T (&input)[ITEMS_PER_THREAD],
OutputType (&output)[ITEMS_PER_THREAD],
DifferenceOpT difference_op)
{
base_type::subtract_left(
input, output, difference_op, temp_storage_
);
}
template <int ITEMS_PER_THREAD, typename OutputT, typename DifferenceOpT>
HIPCUB_DEVICE inline
void SubtractLeft(T (&input)[ITEMS_PER_THREAD],
OutputT (&output)[ITEMS_PER_THREAD],
DifferenceOpT difference_op,
T tile_predecessor_item)
{
base_type::subtract_left(
input, output, difference_op, tile_predecessor_item, temp_storage_
);
}
template <int ITEMS_PER_THREAD, typename OutputType, typename DifferenceOpT>
HIPCUB_DEVICE inline
void SubtractLeftPartialTile(T (&input)[ITEMS_PER_THREAD],
OutputType (&output)[ITEMS_PER_THREAD],
DifferenceOpT difference_op,
int valid_items)
{
base_type::subtract_left_partial(
input, output, difference_op, valid_items, temp_storage_
);
}
template <int ITEMS_PER_THREAD, typename OutputT, typename DifferenceOpT>
HIPCUB_DEVICE inline
void SubtractRight(T (&input)[ITEMS_PER_THREAD],
OutputT (&output)[ITEMS_PER_THREAD],
DifferenceOpT difference_op)
{
base_type::subtract_right(
input, output, difference_op, temp_storage_
);
}
template <int ITEMS_PER_THREAD, typename OutputT, typename DifferenceOpT>
HIPCUB_DEVICE inline
void SubtractRight(T (&input)[ITEMS_PER_THREAD],
OutputT (&output)[ITEMS_PER_THREAD],
DifferenceOpT difference_op,
T tile_successor_item)
{
base_type::subtract_right(
input, output, difference_op, tile_successor_item, temp_storage_
);
}
template <int ITEMS_PER_THREAD, typename OutputT, typename DifferenceOpT>
HIPCUB_DEVICE inline
void SubtractRightPartialTile(T (&input)[ITEMS_PER_THREAD],
OutputT (&output)[ITEMS_PER_THREAD],
DifferenceOpT difference_op,
int valid_items)
{
base_type::subtract_right_partial(
input, output, difference_op, valid_items, temp_storage_
);
}
private:
HIPCUB_DEVICE inline
TempStorage& private_storage()
{
HIPCUB_SHARED_MEMORY TempStorage private_storage;
return private_storage;
}
};
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_
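// Usage sketch (kernel, block size, and operator are hypothetical): compute
// per-thread left differences across a 128-thread block, 4 items per thread,
// with the SubtractLeft member defined above:
//
//     __global__ void diff_kernel(const int* d_in, int* d_out)
//     {
//         using BlockAdjDiff = hipcub::BlockAdjacentDifference<int, 128>;
//         __shared__ typename BlockAdjDiff::TempStorage temp;
//         int items[4];
//         // ... load 4 items per thread into `items` ...
//         int diffs[4];
//         BlockAdjDiff(temp).SubtractLeft(items, diffs, hipcub::Difference());
//         // ... store `diffs` ...
//     }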
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_
#include "../config.hpp"
#include <cub/rocprim/block/block_discontinuity.hpp>
BEGIN_HIPCUB_NAMESPACE
template<
typename T,
int BLOCK_DIM_X,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH /* ignored */
>
class BlockDiscontinuity
: private ::rocprim::block_discontinuity<
T,
BLOCK_DIM_X,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>
{
static_assert(
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
"BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
);
using base_type =
typename ::rocprim::block_discontinuity<
T,
BLOCK_DIM_X,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>;
// Reference to temporary storage (usually shared memory)
typename base_type::storage_type& temp_storage_;
public:
using TempStorage = typename base_type::storage_type;
HIPCUB_DEVICE inline
BlockDiscontinuity() : temp_storage_(private_storage())
{
}
HIPCUB_DEVICE inline
BlockDiscontinuity(TempStorage& temp_storage) : temp_storage_(temp_storage)
{
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
base_type::flag_heads(head_flags, input, flag_op, temp_storage_);
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op,
T tile_predecessor_item)
{
base_type::flag_heads(head_flags, tile_predecessor_item, input, flag_op, temp_storage_);
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
base_type::flag_tails(tail_flags, input, flag_op, temp_storage_);
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op,
T tile_successor_item)
{
base_type::flag_tails(tail_flags, tile_successor_item, input, flag_op, temp_storage_);
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
base_type::flag_heads_and_tails(
head_flags, tail_flags, input,
flag_op, temp_storage_
);
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T tile_successor_item,
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
base_type::flag_heads_and_tails(
head_flags, tail_flags, tile_successor_item, input,
flag_op, temp_storage_
);
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
T tile_predecessor_item,
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
base_type::flag_heads_and_tails(
head_flags, tile_predecessor_item, tail_flags, input,
flag_op, temp_storage_
);
}
template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
HIPCUB_DEVICE inline
void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD],
T tile_predecessor_item,
FlagT (&tail_flags)[ITEMS_PER_THREAD],
T tile_successor_item,
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
base_type::flag_heads_and_tails(
head_flags, tile_predecessor_item, tail_flags, tile_successor_item, input,
flag_op, temp_storage_
);
}
private:
HIPCUB_DEVICE inline
TempStorage& private_storage()
{
HIPCUB_SHARED_MEMORY TempStorage private_storage;
return private_storage;
}
};
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_
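// Usage sketch (hypothetical kernel fragment): flag the first item of each
// run of equal values, seeding the comparison across tiles with a
// predecessor item, using the FlagHeads overload defined above:
//
//     using BlockDisc = hipcub::BlockDiscontinuity<int, 128>;
//     __shared__ typename BlockDisc::TempStorage temp;
//     int head_flags[4];
//     BlockDisc(temp).FlagHeads(head_flags, items, hipcub::Inequality(),
//                               tile_predecessor_item);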
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_
#include "../config.hpp"
#include <cub/rocprim/block/block_exchange.hpp>
BEGIN_HIPCUB_NAMESPACE
template<
typename InputT,
int BLOCK_DIM_X,
int ITEMS_PER_THREAD,
bool WARP_TIME_SLICING = false, /* ignored */
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH /* ignored */
>
class BlockExchange
: private ::rocprim::block_exchange<
InputT,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>
{
static_assert(
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
"BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
);
using base_type =
typename ::rocprim::block_exchange<
InputT,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>;
// Reference to temporary storage (usually shared memory)
typename base_type::storage_type& temp_storage_;
public:
using TempStorage = typename base_type::storage_type;
HIPCUB_DEVICE inline
BlockExchange() : temp_storage_(private_storage())
{
}
HIPCUB_DEVICE inline
BlockExchange(TempStorage& temp_storage) : temp_storage_(temp_storage)
{
}
template<typename OutputT>
HIPCUB_DEVICE inline
void StripedToBlocked(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD])
{
base_type::striped_to_blocked(input_items, output_items, temp_storage_);
}
template<typename OutputT>
HIPCUB_DEVICE inline
void BlockedToStriped(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD])
{
base_type::blocked_to_striped(input_items, output_items, temp_storage_);
}
template<typename OutputT>
HIPCUB_DEVICE inline
void WarpStripedToBlocked(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD])
{
base_type::warp_striped_to_blocked(input_items, output_items, temp_storage_);
}
template<typename OutputT>
HIPCUB_DEVICE inline
void BlockedToWarpStriped(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD])
{
base_type::blocked_to_warp_striped(input_items, output_items, temp_storage_);
}
template<typename OutputT, typename OffsetT>
HIPCUB_DEVICE inline
void ScatterToBlocked(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD],
OffsetT (&ranks)[ITEMS_PER_THREAD])
{
base_type::scatter_to_blocked(input_items, output_items, ranks, temp_storage_);
}
template<typename OutputT, typename OffsetT>
HIPCUB_DEVICE inline
void ScatterToStriped(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD],
OffsetT (&ranks)[ITEMS_PER_THREAD])
{
base_type::scatter_to_striped(input_items, output_items, ranks, temp_storage_);
}
template<typename OutputT, typename OffsetT>
HIPCUB_DEVICE inline
void ScatterToStripedGuarded(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD],
OffsetT (&ranks)[ITEMS_PER_THREAD])
{
base_type::scatter_to_striped_guarded(input_items, output_items, ranks, temp_storage_);
}
template<typename OutputT, typename OffsetT, typename ValidFlag>
HIPCUB_DEVICE inline
void ScatterToStripedFlagged(InputT (&input_items)[ITEMS_PER_THREAD],
OutputT (&output_items)[ITEMS_PER_THREAD],
OffsetT (&ranks)[ITEMS_PER_THREAD],
ValidFlag (&is_valid)[ITEMS_PER_THREAD])
{
base_type::scatter_to_striped_flagged(input_items, output_items, ranks, is_valid, temp_storage_);
}
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
HIPCUB_DEVICE inline void StripedToBlocked(
InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
StripedToBlocked(items, items);
}
HIPCUB_DEVICE inline void BlockedToStriped(
InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
BlockedToStriped(items, items);
}
HIPCUB_DEVICE inline void WarpStripedToBlocked(
InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
WarpStripedToBlocked(items, items);
}
HIPCUB_DEVICE inline void BlockedToWarpStriped(
InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
BlockedToWarpStriped(items, items);
}
template <typename OffsetT>
HIPCUB_DEVICE inline void ScatterToBlocked(
InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
OffsetT (&ranks)[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks
{
ScatterToBlocked(items, items, ranks);
}
template <typename OffsetT>
HIPCUB_DEVICE inline void ScatterToStriped(
InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
OffsetT (&ranks)[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks
{
ScatterToStriped(items, items, ranks);
}
template <typename OffsetT>
HIPCUB_DEVICE inline void ScatterToStripedGuarded(
InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
OffsetT (&ranks)[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks
{
ScatterToStripedGuarded(items, items, ranks);
}
template <typename OffsetT, typename ValidFlag>
HIPCUB_DEVICE inline void ScatterToStripedFlagged(
InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
OffsetT (&ranks)[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks
ValidFlag (&is_valid)[ITEMS_PER_THREAD]) ///< [in] Corresponding flag denoting item validity
{
// Forward to the four-argument flagged overload above; plain
// ScatterToStriped has no is_valid parameter.
ScatterToStripedFlagged(items, items, ranks, is_valid);
}
#endif // DOXYGEN_SHOULD_SKIP_THIS
private:
HIPCUB_DEVICE inline
TempStorage& private_storage()
{
HIPCUB_SHARED_MEMORY TempStorage private_storage;
return private_storage;
}
};
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_
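// Usage sketch (hypothetical kernel fragment): convert a striped load into a
// blocked arrangement so each thread ends up holding a contiguous chunk,
// via the in-place StripedToBlocked overload defined above:
//
//     using BlockExch = hipcub::BlockExchange<int, 128, 4>;
//     __shared__ typename BlockExch::TempStorage temp;
//     int items[4];
//     hipcub::LoadDirectStriped<128>(threadIdx.x, d_in, items);
//     BlockExch(temp).StripedToBlocked(items, items);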
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_
#include <type_traits>
#include "../config.hpp"
#include <cub/rocprim/block/block_histogram.hpp>
BEGIN_HIPCUB_NAMESPACE
namespace detail
{
inline constexpr
typename std::underlying_type<::rocprim::block_histogram_algorithm>::type
to_BlockHistogramAlgorithm_enum(::rocprim::block_histogram_algorithm v)
{
using utype = std::underlying_type<::rocprim::block_histogram_algorithm>::type;
return static_cast<utype>(v);
}
}
enum BlockHistogramAlgorithm
{
BLOCK_HISTO_ATOMIC
= detail::to_BlockHistogramAlgorithm_enum(::rocprim::block_histogram_algorithm::using_atomic),
BLOCK_HISTO_SORT
= detail::to_BlockHistogramAlgorithm_enum(::rocprim::block_histogram_algorithm::using_sort)
};
template<
typename T,
int BLOCK_DIM_X,
int ITEMS_PER_THREAD,
int BINS,
BlockHistogramAlgorithm ALGORITHM = BLOCK_HISTO_SORT,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH /* ignored */
>
class BlockHistogram
: private ::rocprim::block_histogram<
T,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
BINS,
static_cast<::rocprim::block_histogram_algorithm>(ALGORITHM),
BLOCK_DIM_Y,
BLOCK_DIM_Z
>
{
static_assert(
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
"BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
);
using base_type =
typename ::rocprim::block_histogram<
T,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
BINS,
static_cast<::rocprim::block_histogram_algorithm>(ALGORITHM),
BLOCK_DIM_Y,
BLOCK_DIM_Z
>;
// Reference to temporary storage (usually shared memory)
typename base_type::storage_type& temp_storage_;
public:
using TempStorage = typename base_type::storage_type;
HIPCUB_DEVICE inline
BlockHistogram() : temp_storage_(private_storage())
{
}
HIPCUB_DEVICE inline
BlockHistogram(TempStorage& temp_storage) : temp_storage_(temp_storage)
{
}
template<class CounterT>
HIPCUB_DEVICE inline
void InitHistogram(CounterT histogram[BINS])
{
base_type::init_histogram(histogram);
}
template<class CounterT>
HIPCUB_DEVICE inline
void Composite(T (&items)[ITEMS_PER_THREAD],
CounterT histogram[BINS])
{
base_type::composite(items, histogram, temp_storage_);
}
template<class CounterT>
HIPCUB_DEVICE inline
void Histogram(T (&items)[ITEMS_PER_THREAD],
CounterT histogram[BINS])
{
base_type::init_histogram(histogram);
CTA_SYNC();
base_type::composite(items, histogram, temp_storage_);
}
private:
HIPCUB_DEVICE inline
TempStorage& private_storage()
{
HIPCUB_SHARED_MEMORY TempStorage private_storage;
return private_storage;
}
};
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_
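// Usage sketch (hypothetical kernel fragment): build a 256-bin shared-memory
// histogram from 4 samples per thread across a 128-thread block, using the
// Histogram member defined above (init + sync + composite):
//
//     using BlockHist = hipcub::BlockHistogram<unsigned char, 128, 4, 256>;
//     __shared__ typename BlockHist::TempStorage temp;
//     __shared__ unsigned int smem_histogram[256];
//     unsigned char samples[4];
//     // ... load samples ...
//     BlockHist(temp).Histogram(samples, smem_histogram);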
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_HPP_
#include <type_traits>
#include "../config.hpp"
#include <cub/rocprim/block/block_load.hpp>
#include "block_load_func.cuh"
BEGIN_HIPCUB_NAMESPACE
namespace detail
{
inline constexpr
typename std::underlying_type<::rocprim::block_load_method>::type
to_BlockLoadAlgorithm_enum(::rocprim::block_load_method v)
{
using utype = std::underlying_type<::rocprim::block_load_method>::type;
return static_cast<utype>(v);
}
}
enum BlockLoadAlgorithm
{
BLOCK_LOAD_DIRECT
= detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_direct),
BLOCK_LOAD_STRIPED
= detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_striped),
BLOCK_LOAD_VECTORIZE
= detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_vectorize),
BLOCK_LOAD_TRANSPOSE
= detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_transpose),
BLOCK_LOAD_WARP_TRANSPOSE
= detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_warp_transpose),
BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED
= detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_warp_transpose)
};
template<
typename T,
int BLOCK_DIM_X,
int ITEMS_PER_THREAD,
BlockLoadAlgorithm ALGORITHM = BLOCK_LOAD_DIRECT,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH /* ignored */
>
class BlockLoad
: private ::rocprim::block_load<
T,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
static_cast<::rocprim::block_load_method>(ALGORITHM),
BLOCK_DIM_Y,
BLOCK_DIM_Z
>
{
static_assert(
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
"BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
);
using base_type =
typename ::rocprim::block_load<
T,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
static_cast<::rocprim::block_load_method>(ALGORITHM),
BLOCK_DIM_Y,
BLOCK_DIM_Z
>;
// Reference to temporary storage (usually shared memory)
typename base_type::storage_type& temp_storage_;
public:
using TempStorage = typename base_type::storage_type;
HIPCUB_DEVICE inline
BlockLoad() : temp_storage_(private_storage())
{
}
HIPCUB_DEVICE inline
BlockLoad(TempStorage& temp_storage) : temp_storage_(temp_storage)
{
}
template<class InputIteratorT>
HIPCUB_DEVICE inline
void Load(InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD])
{
base_type::load(block_iter, items, temp_storage_);
}
template<class InputIteratorT>
HIPCUB_DEVICE inline
void Load(InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items)
{
base_type::load(block_iter, items, valid_items, temp_storage_);
}
template<
class InputIteratorT,
class Default
>
HIPCUB_DEVICE inline
void Load(InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items,
Default oob_default)
{
base_type::load(block_iter, items, valid_items, oob_default, temp_storage_);
}
private:
HIPCUB_DEVICE inline
TempStorage& private_storage()
{
HIPCUB_SHARED_MEMORY TempStorage private_storage;
return private_storage;
}
};
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_HPP_
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_
#include "../config.hpp"
#include <cub/rocprim/block/block_load_func.hpp>
BEGIN_HIPCUB_NAMESPACE
template<
typename T,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectBlocked(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD])
{
::rocprim::block_load_direct_blocked(
linear_id, block_iter, items
);
}
template<
typename T,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectBlocked(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items)
{
::rocprim::block_load_direct_blocked(
linear_id, block_iter, items, valid_items
);
}
template<
typename T,
typename Default,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectBlocked(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items,
Default oob_default)
{
::rocprim::block_load_direct_blocked(
linear_id, block_iter, items, valid_items, oob_default
);
}
template <
typename T,
int ITEMS_PER_THREAD
>
HIPCUB_DEVICE inline
void LoadDirectBlockedVectorized(int linear_id,
T* block_iter,
T (&items)[ITEMS_PER_THREAD])
{
::rocprim::block_load_direct_blocked_vectorized(
linear_id, block_iter, items
);
}
template<
int BLOCK_THREADS,
typename T,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectStriped(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD])
{
::rocprim::block_load_direct_striped<BLOCK_THREADS>(
linear_id, block_iter, items
);
}
template<
int BLOCK_THREADS,
typename T,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectStriped(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items)
{
::rocprim::block_load_direct_striped<BLOCK_THREADS>(
linear_id, block_iter, items, valid_items
);
}
template<
int BLOCK_THREADS,
typename T,
typename Default,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectStriped(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items,
Default oob_default)
{
::rocprim::block_load_direct_striped<BLOCK_THREADS>(
linear_id, block_iter, items, valid_items, oob_default
);
}
template<
typename T,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectWarpStriped(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD])
{
::rocprim::block_load_direct_warp_striped(
linear_id, block_iter, items
);
}
template<
typename T,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectWarpStriped(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items)
{
::rocprim::block_load_direct_warp_striped(
linear_id, block_iter, items, valid_items
);
}
template<
typename T,
typename Default,
int ITEMS_PER_THREAD,
typename InputIteratorT
>
HIPCUB_DEVICE inline
void LoadDirectWarpStriped(int linear_id,
InputIteratorT block_iter,
T (&items)[ITEMS_PER_THREAD],
int valid_items,
Default oob_default)
{
::rocprim::block_load_direct_warp_striped(
linear_id, block_iter, items, valid_items, oob_default
);
}
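/**
 * Usage sketch (illustrative; kernel and buffer names are placeholders):
 * a guarded striped load that fills out-of-range slots with a default value.
 *
 * \code
 * __global__ void GuardedLoadKernel(const float *d_in, int valid_items)
 * {
 *     float items[4];
 *     // With 128 threads, thread i receives elements i, i + 128, i + 256,
 *     // i + 384; slots at or past valid_items are filled with -1.0f
 *     LoadDirectStriped<128>(threadIdx.x, d_in, items, valid_items, -1.0f);
 * }
 * \endcode
 */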
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
#include "../thread/thread_sort.hpp"
#include "../util_math.cuh"
#include "../util_type.cuh"
#include <cub/rocprim/detail/various.hpp>
#include <cub/rocprim/functional.hpp>
BEGIN_HIPCUB_NAMESPACE
// Additional details of the Merge-Path Algorithm can be found in:
// S. Odeh, O. Green, Z. Mwassi, O. Shmueli, Y. Birk, "Merge Path - Parallel
// Merging Made Simple", Multithreaded Architectures and Applications (MTAAP)
// Workshop, IEEE 26th International Parallel & Distributed Processing
// Symposium (IPDPS), 2012
template <typename KeyT,
typename KeyIteratorT,
typename OffsetT,
typename BinaryPred>
HIPCUB_DEVICE __forceinline__ OffsetT MergePath(KeyIteratorT keys1,
KeyIteratorT keys2,
OffsetT keys1_count,
OffsetT keys2_count,
OffsetT diag,
BinaryPred binary_pred)
{
OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count;
OffsetT keys1_end = (::rocprim::min)(diag, keys1_count);
while (keys1_begin < keys1_end)
{
OffsetT mid = cub::MidPoint<OffsetT>(keys1_begin, keys1_end);
KeyT key1 = keys1[mid];
KeyT key2 = keys2[diag - 1 - mid];
bool pred = binary_pred(key2, key1);
if (pred)
{
keys1_end = mid;
}
else
{
keys1_begin = mid + 1;
}
}
return keys1_begin;
}
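// Worked illustration: with keys1 = {1,3,5}, keys2 = {2,4,6}, diag = 3 and a
// less-than predicate, MergePath returns 2: the merged prefix of length diag
// takes keys1[0..2) = {1,3} and keys2[0..1) = {2}. In general, a return
// value of k means keys1[0..k) and keys2[0..diag-k) form the first diag
// elements of the merged order.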
template <typename KeyT, typename CompareOp, int ITEMS_PER_THREAD>
HIPCUB_DEVICE __forceinline__ void SerialMerge(KeyT *keys_shared,
int keys1_beg,
int keys2_beg,
int keys1_count,
int keys2_count,
KeyT (&output)[ITEMS_PER_THREAD],
int (&indices)[ITEMS_PER_THREAD],
CompareOp compare_op)
{
int keys1_end = keys1_beg + keys1_count;
int keys2_end = keys2_beg + keys2_count;
KeyT key1 = keys_shared[keys1_beg];
KeyT key2 = keys_shared[keys2_beg];
#pragma unroll
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
{
bool p = (keys2_beg < keys2_end) &&
((keys1_beg >= keys1_end)
|| compare_op(key2, key1));
output[item] = p ? key2 : key1;
indices[item] = p ? keys2_beg++ : keys1_beg++;
if (p)
{
key2 = keys_shared[keys2_beg];
}
else
{
key1 = keys_shared[keys1_beg];
}
}
}
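// Note: `indices` records each merged output's source position within
// keys_shared; key-value sorts replay these indices to gather the
// corresponding values (see the !KEYS_ONLY path in BlockMergeSortStrategy::Sort).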
/**
* @brief Generalized merge sort algorithm
*
* This class is used to reduce code duplication. Warp and Block merge sort
* differ only in how they compute thread index and how they synchronize
* threads. Since synchronization might require access to custom data
* (like member mask), CRTP is used.
*
* @par
* The code snippet below illustrates the way this class can be used.
* @par
* @code
* #include <hipcub/hipcub.hpp> // or equivalently <hipcub/block/block_merge_sort.hpp>
*
* constexpr int BLOCK_THREADS = 256;
* constexpr int ITEMS_PER_THREAD = 9;
*
* class BlockMergeSort : public BlockMergeSortStrategy<int,
* cub::NullType,
* BLOCK_THREADS,
* ITEMS_PER_THREAD,
* BlockMergeSort>
* {
* using BlockMergeSortStrategyT =
* BlockMergeSortStrategy<int,
* cub::NullType,
* BLOCK_THREADS,
* ITEMS_PER_THREAD,
* BlockMergeSort>;
* public:
* __device__ __forceinline__ explicit BlockMergeSort(
* typename BlockMergeSortStrategyT::TempStorage &temp_storage)
* : BlockMergeSortStrategyT(temp_storage, threadIdx.x)
* {}
*
* __device__ __forceinline__ void SyncImplementation() const
* {
* __syncthreads();
* }
* };
* @endcode
*
* @tparam KeyT
* KeyT type
*
* @tparam ValueT
* ValueT type. cub::NullType indicates a keys-only sort
*
* @tparam SynchronizationPolicy
* Provides a way of synchronizing threads. Should be derived from
* `BlockMergeSortStrategy`.
*/
template <typename KeyT,
typename ValueT,
int NUM_THREADS,
int ITEMS_PER_THREAD,
typename SynchronizationPolicy>
class BlockMergeSortStrategy
{
static_assert(PowerOfTwo<NUM_THREADS>::VALUE,
"NUM_THREADS must be a power of two");
private:
static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * NUM_THREADS;
// Whether or not there are values to be trucked along with keys
static constexpr bool KEYS_ONLY = ::rocprim::Equals<ValueT, NullType>::VALUE;
/// Shared memory type required by this thread block
union _TempStorage
{
KeyT keys_shared[ITEMS_PER_TILE + 1];
ValueT items_shared[ITEMS_PER_TILE + 1];
}; // union _TempStorage
/// Shared storage reference
_TempStorage &temp_storage;
/// Internal storage allocator
HIPCUB_DEVICE __forceinline__ _TempStorage& PrivateStorage()
{
__shared__ _TempStorage private_storage;
return private_storage;
}
const unsigned int linear_tid;
public:
/// \smemstorage{BlockMergeSort}
struct TempStorage : Uninitialized<_TempStorage> {};
BlockMergeSortStrategy() = delete;
explicit HIPCUB_DEVICE __forceinline__
BlockMergeSortStrategy(unsigned int linear_tid)
: temp_storage(PrivateStorage())
, linear_tid(linear_tid)
{}
HIPCUB_DEVICE __forceinline__ BlockMergeSortStrategy(TempStorage &temp_storage,
unsigned int linear_tid)
: temp_storage(temp_storage.Alias())
, linear_tid(linear_tid)
{}
HIPCUB_DEVICE __forceinline__ unsigned int get_linear_tid() const
{
return linear_tid;
}
/**
* @brief Sorts items partitioned across a CUDA thread block using
* a merge sorting method.
*
* @par
* Sort is not guaranteed to be stable. That is, suppose that i and j are
* equivalent: neither one is less than the other. It is not guaranteed
* that the relative order of these two elements will be preserved by sort.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @param[in,out] keys
* Keys to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp>
HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
CompareOp compare_op)
{
ValueT items[ITEMS_PER_THREAD];
Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
}
/**
* @brief Sorts items partitioned across a CUDA thread block using
* a merge sorting method.
*
* @par
* - Sort is not guaranteed to be stable. That is, suppose that `i` and `j`
* are equivalent: neither one is less than the other. It is not guaranteed
* that the relative order of these two elements will be preserved by sort.
* - The value of `oob_default` is assigned to all elements that are out of
* `valid_items` boundaries. It's expected that `oob_default` is ordered
* after any value in the `valid_items` boundaries. The algorithm always
* sorts a fixed amount of elements, which is equal to
* `ITEMS_PER_THREAD * BLOCK_THREADS`. If there is a value that is ordered
* after `oob_default`, it won't be placed within `valid_items` boundaries.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @param[in,out] keys
* Keys to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* @param[in] valid_items
* Number of valid items to sort
*
* @param[in] oob_default
* Default value to assign out-of-bound items
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp>
HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
CompareOp compare_op,
int valid_items,
KeyT oob_default)
{
ValueT items[ITEMS_PER_THREAD];
Sort<CompareOp, true>(keys, items, compare_op, valid_items, oob_default);
}
/**
* @brief Sorts items partitioned across a CUDA thread block using a merge sorting method.
*
* @par
* Sort is not guaranteed to be stable. That is, suppose that `i` and `j` are
* equivalent: neither one is less than the other. It is not guaranteed
* that the relative order of these two elements will be preserved by sort.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @param[in,out] keys
* Keys to sort
*
* @param[in,out] items
* Values to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp>
HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&items)[ITEMS_PER_THREAD],
CompareOp compare_op)
{
Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
}
/**
* @brief Sorts items partitioned across a CUDA thread block using
* a merge sorting method.
*
* @par
* - Sort is not guaranteed to be stable. That is, suppose that `i` and `j`
* are equivalent: neither one is less than the other. It is not guaranteed
* that the relative order of these two elements will be preserved by sort.
* - The value of `oob_default` is assigned to all elements that are out of
* `valid_items` boundaries. It's expected that `oob_default` is ordered
* after any value in the `valid_items` boundaries. The algorithm always
* sorts a fixed amount of elements, which is equal to
* `ITEMS_PER_THREAD * BLOCK_THREADS`. If there is a value that is ordered
* after `oob_default`, it won't be placed within `valid_items` boundaries.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @tparam IS_LAST_TILE
 * True if `valid_items` isn't equal to `ITEMS_PER_TILE`
*
* @param[in,out] keys
* Keys to sort
*
* @param[in,out] items
* Values to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* @param[in] valid_items
* Number of valid items to sort
*
* @param[in] oob_default
* Default value to assign out-of-bound items
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp,
bool IS_LAST_TILE = true>
HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&items)[ITEMS_PER_THREAD],
CompareOp compare_op,
int valid_items,
KeyT oob_default)
{
if (IS_LAST_TILE)
{
// if last tile, find valid max_key
// and fill the remaining keys with it
//
KeyT max_key = oob_default;
#pragma unroll
for (int item = 1; item < ITEMS_PER_THREAD; ++item)
{
if (ITEMS_PER_THREAD * static_cast<int>(linear_tid) + item < valid_items)
{
max_key = compare_op(max_key, keys[item]) ? keys[item] : max_key;
}
else
{
keys[item] = max_key;
}
}
}
// if first element of thread is in input range, stable sort items
//
if (!IS_LAST_TILE || ITEMS_PER_THREAD * static_cast<int>(linear_tid) < valid_items)
{
StableOddEvenSort(keys, items, compare_op);
}
// each thread has sorted keys
// merge sort keys in shared memory
//
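// Each pass below doubles the length of the sorted runs: a pass merges
// pairs of runs each owned by merged_threads_number threads, so after
// log2(NUM_THREADS) passes a single sorted run spans the whole block.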
#pragma unroll
for (int target_merged_threads_number = 2;
target_merged_threads_number <= NUM_THREADS;
target_merged_threads_number *= 2)
{
int merged_threads_number = target_merged_threads_number / 2;
int mask = target_merged_threads_number - 1;
Sync();
// store keys in shmem
//
#pragma unroll
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
{
int idx = ITEMS_PER_THREAD * linear_tid + item;
temp_storage.keys_shared[idx] = keys[item];
}
Sync();
int indices[ITEMS_PER_THREAD];
int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid;
int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged;
int size = ITEMS_PER_THREAD * merged_threads_number;
int thread_idx_in_thread_group_being_merged = mask & linear_tid;
int diag =
(::rocprim::min)(valid_items,
ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged);
int keys1_beg = (::rocprim::min)(valid_items, start);
int keys1_end = (::rocprim::min)(valid_items, keys1_beg + size);
int keys2_beg = keys1_end;
int keys2_end = (::rocprim::min)(valid_items, keys2_beg + size);
int keys1_count = keys1_end - keys1_beg;
int keys2_count = keys2_end - keys2_beg;
int partition_diag = MergePath<KeyT>(&temp_storage.keys_shared[keys1_beg],
&temp_storage.keys_shared[keys2_beg],
keys1_count,
keys2_count,
diag,
compare_op);
int keys1_beg_loc = keys1_beg + partition_diag;
int keys1_end_loc = keys1_end;
int keys2_beg_loc = keys2_beg + diag - partition_diag;
int keys2_end_loc = keys2_end;
int keys1_count_loc = keys1_end_loc - keys1_beg_loc;
int keys2_count_loc = keys2_end_loc - keys2_beg_loc;
SerialMerge(&temp_storage.keys_shared[0],
keys1_beg_loc,
keys2_beg_loc,
keys1_count_loc,
keys2_count_loc,
keys,
indices,
compare_op);
if (!KEYS_ONLY)
{
Sync();
// store keys in shmem
//
#pragma unroll
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
{
int idx = ITEMS_PER_THREAD * linear_tid + item;
temp_storage.items_shared[idx] = items[item];
}
Sync();
// gather items from shmem
//
#pragma unroll
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
{
items[item] = temp_storage.items_shared[indices[item]];
}
}
}
} // func Sort
/**
* @brief Sorts items partitioned across a CUDA thread block using
* a merge sorting method.
*
* @par
* StableSort is stable: it preserves the relative ordering of equivalent
* elements. That is, if `x` and `y` are elements such that `x` precedes `y`,
* and if the two elements are equivalent (neither `x < y` nor `y < x`) then
* a postcondition of StableSort is that `x` still precedes `y`.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @param[in,out] keys
* Keys to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp>
HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
CompareOp compare_op)
{
Sort(keys, compare_op);
}
/**
* @brief Sorts items partitioned across a CUDA thread block using
* a merge sorting method.
*
* @par
* StableSort is stable: it preserves the relative ordering of equivalent
* elements. That is, if `x` and `y` are elements such that `x` precedes `y`,
* and if the two elements are equivalent (neither `x < y` nor `y < x`) then
* a postcondition of StableSort is that `x` still precedes `y`.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @param[in,out] keys
* Keys to sort
*
* @param[in,out] items
* Values to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp>
HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&items)[ITEMS_PER_THREAD],
CompareOp compare_op)
{
Sort(keys, items, compare_op);
}
/**
* @brief Sorts items partitioned across a CUDA thread block using
* a merge sorting method.
*
* @par
* - StableSort is stable: it preserves the relative ordering of equivalent
* elements. That is, if `x` and `y` are elements such that `x` precedes
* `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`)
* then a postcondition of StableSort is that `x` still precedes `y`.
* - The value of `oob_default` is assigned to all elements that are out of
* `valid_items` boundaries. It's expected that `oob_default` is ordered
* after any value in the `valid_items` boundaries. The algorithm always
* sorts a fixed amount of elements, which is equal to
* `ITEMS_PER_THREAD * BLOCK_THREADS`.
* If there is a value that is ordered after `oob_default`, it won't be
* placed within `valid_items` boundaries.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @param[in,out] keys
* Keys to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* @param[in] valid_items
* Number of valid items to sort
*
* @param[in] oob_default
* Default value to assign out-of-bound items
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp>
HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
CompareOp compare_op,
int valid_items,
KeyT oob_default)
{
Sort(keys, compare_op, valid_items, oob_default);
}
/**
* @brief Sorts items partitioned across a CUDA thread block using
* a merge sorting method.
*
* @par
* - StableSort is stable: it preserves the relative ordering of equivalent
* elements. That is, if `x` and `y` are elements such that `x` precedes
* `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`)
* then a postcondition of StableSort is that `x` still precedes `y`.
* - The value of `oob_default` is assigned to all elements that are out of
* `valid_items` boundaries. It's expected that `oob_default` is ordered
* after any value in the `valid_items` boundaries. The algorithm always
* sorts a fixed amount of elements, which is equal to
* `ITEMS_PER_THREAD * BLOCK_THREADS`. If there is a value that is ordered
* after `oob_default`, it won't be placed within `valid_items` boundaries.
*
* @tparam CompareOp
* functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
* `CompareOp` is a model of [Strict Weak Ordering].
*
* @tparam IS_LAST_TILE
 * True if `valid_items` isn't equal to `ITEMS_PER_TILE`
*
* @param[in,out] keys
* Keys to sort
*
* @param[in,out] items
* Values to sort
*
* @param[in] compare_op
* Comparison function object which returns true if the first argument is
* ordered before the second
*
* @param[in] valid_items
* Number of valid items to sort
*
* @param[in] oob_default
* Default value to assign out-of-bound items
*
* [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
*/
template <typename CompareOp,
bool IS_LAST_TILE = true>
HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&items)[ITEMS_PER_THREAD],
CompareOp compare_op,
int valid_items,
KeyT oob_default)
{
Sort<CompareOp, IS_LAST_TILE>(keys,
items,
compare_op,
valid_items,
oob_default);
}
private:
HIPCUB_DEVICE __forceinline__ void Sync() const
{
static_cast<const SynchronizationPolicy*>(this)->SyncImplementation();
}
};
/**
* @brief The BlockMergeSort class provides methods for sorting items
* partitioned across a CUDA thread block using a merge sorting method.
* @ingroup BlockModule
*
* @tparam KeyT
* KeyT type
*
* @tparam BLOCK_DIM_X
* The thread block length in threads along the X dimension
*
* @tparam ITEMS_PER_THREAD
* The number of items per thread
*
* @tparam ValueT
* **[optional]** ValueT type (default: `cub::NullType`, which indicates
* a keys-only sort)
*
* @tparam BLOCK_DIM_Y
* **[optional]** The thread block length in threads along the Y dimension
* (default: 1)
*
* @tparam BLOCK_DIM_Z
* **[optional]** The thread block length in threads along the Z dimension
* (default: 1)
*
* @par Overview
* BlockMergeSort arranges items into ascending order using a comparison
* functor with less-than semantics. Merge sort can handle arbitrary types
* and comparison functors, but is slower than BlockRadixSort when sorting
* arithmetic types into ascending/descending order.
*
* @par A Simple Example
* @blockcollective{BlockMergeSort}
* @par
* The code snippet below illustrates a sort of 512 integer keys that are
 * partitioned across 128 threads where each thread owns 4 consecutive items.
* @par
* @code
* #include <hipcub/hipcub.hpp> // or equivalently <hipcub/block/block_merge_sort.hpp>
*
* struct CustomLess
* {
* template <typename DataType>
* __device__ bool operator()(const DataType &lhs, const DataType &rhs)
* {
* return lhs < rhs;
* }
* };
*
* __global__ void ExampleKernel(...)
* {
* // Specialize BlockMergeSort for a 1D block of 128 threads owning 4 integer items each
* typedef cub::BlockMergeSort<int, 128, 4> BlockMergeSort;
*
* // Allocate shared memory for BlockMergeSort
* __shared__ typename BlockMergeSort::TempStorage temp_storage_shuffle;
*
* // Obtain a segment of consecutive items that are blocked across threads
* int thread_keys[4];
* ...
*
* BlockMergeSort(temp_storage_shuffle).Sort(thread_keys, CustomLess());
* ...
* }
* @endcode
* @par
* Suppose the set of input `thread_keys` across the block of threads is
* `{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`.
* The corresponding output `thread_keys` in those threads will be
* `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`.
*
 * @par Re-using dynamically allocated shared memory
 * The following example under the examples/block folder illustrates usage of
 * dynamically allocated shared memory with BlockReduce and how to re-purpose
* the same memory region:
* <a href="../../examples/block/example_block_reduce_dyn_smem.cu">example_block_reduce_dyn_smem.cu</a>
*
* This example can be easily adapted to the storage required by BlockMergeSort.
*/
template <typename KeyT,
int BLOCK_DIM_X,
int ITEMS_PER_THREAD,
typename ValueT = NullType,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1>
class BlockMergeSort
: public BlockMergeSortStrategy<KeyT,
ValueT,
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
ITEMS_PER_THREAD,
BlockMergeSort<KeyT,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
ValueT,
BLOCK_DIM_Y,
BLOCK_DIM_Z>>
{
private:
// The thread block size in threads
static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;
static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS;
using BlockMergeSortStrategyT =
BlockMergeSortStrategy<KeyT,
ValueT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
BlockMergeSort>;
public:
HIPCUB_DEVICE __forceinline__ BlockMergeSort()
: BlockMergeSortStrategyT(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
{}
HIPCUB_DEVICE __forceinline__ explicit BlockMergeSort(
typename BlockMergeSortStrategyT::TempStorage &temp_storage)
: BlockMergeSortStrategyT(
temp_storage,
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
{}
private:
HIPCUB_DEVICE __forceinline__ void SyncImplementation() const
{
CTA_SYNC();
}
friend BlockMergeSortStrategyT;
};
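/**
 * Guarded usage sketch (illustrative; kernel name and tile shape are
 * placeholders): sorting a partially full tile with `valid_items` and an
 * `oob_default` that orders after every valid key. `CustomLess` is the
 * comparator from the example above.
 *
 * \code
 * __global__ void SortPartialTile(int *d_keys, int valid_items)
 * {
 *     using BlockMergeSortT = BlockMergeSort<int, 128, 4>;
 *     __shared__ typename BlockMergeSortT::TempStorage temp_storage;
 *
 *     int thread_keys[4];
 *     ...   // load up to valid_items keys
 *
 *     // INT_MAX orders after every valid key, so padded slots sink to the end
 *     BlockMergeSortT(temp_storage).Sort(thread_keys, CustomLess(),
 *                                        valid_items, INT_MAX);
 * }
 * \endcode
 */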
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block
*/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
#include <stdint.h>
#include "../config.hpp"
#include "../util_type.cuh"
#include "../util_ptx.cuh"
#include "../thread/thread_reduce.cuh"
#include "../thread/thread_scan.cuh"
#include "../block/block_scan.cuh"
#include "../block/radix_rank_sort_operations.hpp"
BEGIN_HIPCUB_NAMESPACE
/**
* \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
* \ingroup BlockModule
*
* \tparam BLOCK_DIM_X The thread block length in threads along the X dimension
* \tparam RADIX_BITS The number of radix bits per digit place
* \tparam IS_DESCENDING Whether or not the sorted-order is high-to-low
 * \tparam MEMOIZE_OUTER_SCAN <b>[optional]</b> Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: false). See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details.
* \tparam INNER_SCAN_ALGORITHM <b>[optional]</b> The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
* \tparam SMEM_CONFIG <b>[optional]</b> Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte)
* \tparam BLOCK_DIM_Y <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
* \tparam BLOCK_DIM_Z <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
* \tparam ARCH <b>[optional]</b> \ptxversion
*
* \par Overview
 * BlockRadixRank computes the local rank of each key in a tile, i.e. the
 * position the key would occupy if the tile were stably sorted by the digit
 * at the current radix place.
* - Keys must be in a form suitable for radix ranking (i.e., unsigned bits).
* - \blocked
*
* \par Performance Considerations
* - \granularity
*
* \par Examples
* \par
* - <b>Example 1:</b> Simple radix rank of 32-bit integer keys
 * \code
 * #include <hipcub/hipcub.hpp>
 *
 * template <int BLOCK_THREADS>
 * __global__ void ExampleKernel(...)
 * {
 *     // Specialize BlockRadixRank for a 1D block of BLOCK_THREADS threads, 4 radix bits
 *     typedef BlockRadixRank<BLOCK_THREADS, 4, false> BlockRadixRankT;
 *     __shared__ typename BlockRadixRankT::TempStorage temp_storage;
 *
 *     unsigned int keys[2];
 *     int ranks[2];
 *     ...   // obtain keys and a digit extractor
 *     BlockRadixRankT(temp_storage).RankKeys(keys, ranks, digit_extractor);
 * }
 * \endcode
*/
template <
int BLOCK_DIM_X,
int RADIX_BITS,
bool IS_DESCENDING,
bool MEMOIZE_OUTER_SCAN = false,
BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH /* ignored */>
class BlockRadixRank
{
private:
/******************************************************************************
* Type definitions and constants
******************************************************************************/
// Integer type for digit counters (to be packed into words of type PackedCounters)
typedef unsigned short DigitCounter;
// Integer type for packing DigitCounters into columns of shared memory banks
typedef typename std::conditional<(SMEM_CONFIG == cudaSharedMemBankSizeEightByte),
unsigned long long,
unsigned int>::type PackedCounter;
enum
{
// The thread block size in threads
BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
RADIX_DIGITS = 1 << RADIX_BITS,
LOG_WARP_THREADS = Log2<HIPCUB_DEVICE_WARP_THREADS>::VALUE, // log2 of the hardware wavefront size
WARP_THREADS = 1 << LOG_WARP_THREADS,
WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
BYTES_PER_COUNTER = sizeof(DigitCounter),
LOG_BYTES_PER_COUNTER = Log2<BYTES_PER_COUNTER>::VALUE,
PACKING_RATIO = sizeof(PackedCounter) / sizeof(DigitCounter),
LOG_PACKING_RATIO = Log2<PACKING_RATIO>::VALUE,
LOG_COUNTER_LANES = rocprim::maximum<int>()((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0), // Always at least one lane
COUNTER_LANES = 1 << LOG_COUNTER_LANES,
// The number of packed counters per thread (plus one for padding)
PADDED_COUNTER_LANES = COUNTER_LANES + 1,
RAKING_SEGMENT = PADDED_COUNTER_LANES,
};
public:
enum
{
/// Number of bin-starting offsets tracked per thread
BINS_TRACKED_PER_THREAD = rocprim::maximum<int>()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
};
private:
/// BlockScan type
typedef BlockScan<
PackedCounter,
BLOCK_DIM_X,
INNER_SCAN_ALGORITHM,
BLOCK_DIM_Y,
BLOCK_DIM_Z,
ARCH>
BlockScan;
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
/// Shared memory storage layout type for BlockRadixRank
struct __align__(16) _TempStorage
{
union Aliasable
{
DigitCounter digit_counters[PADDED_COUNTER_LANES * BLOCK_THREADS * PACKING_RATIO];
PackedCounter raking_grid[BLOCK_THREADS * RAKING_SEGMENT];
} aliasable;
// Storage for scanning local ranks
typename BlockScan::TempStorage block_scan;
};
#endif
/******************************************************************************
* Thread fields
******************************************************************************/
/// Shared storage reference
_TempStorage &temp_storage;
/// Linear thread-id
unsigned int linear_tid;
/// Copy of raking segment, promoted to registers
PackedCounter cached_segment[RAKING_SEGMENT];
/******************************************************************************
* Utility methods
******************************************************************************/
/**
* Internal storage allocator
*/
HIPCUB_DEVICE inline _TempStorage& PrivateStorage()
{
__shared__ _TempStorage private_storage;
return private_storage;
}
/**
* Performs upsweep raking reduction, returning the aggregate
*/
HIPCUB_DEVICE inline PackedCounter Upsweep()
{
PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT];
PackedCounter *raking_ptr;
if (MEMOIZE_OUTER_SCAN)
{
// Copy data into registers
#pragma unroll
for (int i = 0; i < RAKING_SEGMENT; i++)
{
cached_segment[i] = smem_raking_ptr[i];
}
raking_ptr = cached_segment;
}
else
{
raking_ptr = smem_raking_ptr;
}
return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
}
/// Performs exclusive downsweep raking scan
HIPCUB_DEVICE inline void ExclusiveDownsweep(
PackedCounter raking_partial)
{
PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT];
PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
cached_segment :
smem_raking_ptr;
// Exclusive raking downsweep scan
internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);
if (MEMOIZE_OUTER_SCAN)
{
// Copy data back to smem
#pragma unroll
for (int i = 0; i < RAKING_SEGMENT; i++)
{
smem_raking_ptr[i] = cached_segment[i];
}
}
}
/**
* Reset shared memory digit counters
*/
HIPCUB_DEVICE inline void ResetCounters()
{
// Reset shared memory digit counters
#pragma unroll
for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
{
#pragma unroll
for (int SUB_COUNTER = 0; SUB_COUNTER < PACKING_RATIO; SUB_COUNTER++)
{
temp_storage.aliasable.digit_counters[(LANE * BLOCK_THREADS + linear_tid) * PACKING_RATIO + SUB_COUNTER] = 0;
}
}
}
/**
* Block-scan prefix callback
*/
struct PrefixCallBack
{
HIPCUB_DEVICE inline PackedCounter operator()(PackedCounter block_aggregate)
{
PackedCounter block_prefix = 0;
// Propagate totals in packed fields
#pragma unroll
for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
{
block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED);
}
return block_prefix;
}
};
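// Example: with two 16-bit DigitCounters packed per 32-bit PackedCounter
// (PACKING_RATIO == 2), a block aggregate A yields block_prefix = A << 16,
// so counters in the upper half-word start after everything counted in the
// lower half-word.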
/**
* Scan shared memory digit counters.
*/
HIPCUB_DEVICE inline void ScanCounters()
{
// Upsweep scan
PackedCounter raking_partial = Upsweep();
// Compute exclusive sum
PackedCounter exclusive_partial;
PrefixCallBack prefix_call_back;
BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);
// Downsweep scan with exclusive partial
ExclusiveDownsweep(exclusive_partial);
}
public:
/// \smemstorage{BlockRadixRank}
struct TempStorage : Uninitialized<_TempStorage> {};
/******************************************************************//**
* \name Collective constructors
*********************************************************************/
//@{
/**
* \brief Collective constructor using a private static allocation of shared memory as temporary storage.
*/
HIPCUB_DEVICE inline BlockRadixRank()
:
temp_storage(PrivateStorage()),
linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
{}
/**
* \brief Collective constructor using the specified memory allocation as temporary storage.
*/
HIPCUB_DEVICE inline BlockRadixRank(
TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage
:
temp_storage(temp_storage.Alias()),
linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
{}
//@} end member group
/******************************************************************//**
* \name Raking
*********************************************************************/
//@{
/**
* \brief Rank keys.
*/
template <
typename UnsignedBits,
int KEYS_PER_THREAD,
typename DigitExtractorT>
HIPCUB_DEVICE inline void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile
DigitExtractorT digit_extractor) ///< [in] The digit extractor
{
DigitCounter thread_prefixes[KEYS_PER_THREAD]; // For each key, the count of previous keys in this tile having the same digit
DigitCounter* digit_counters[KEYS_PER_THREAD]; // For each key, a pointer to its corresponding digit counter in smem
// Reset shared memory digit counters
ResetCounters();
#pragma unroll
for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
{
// Get digit
unsigned int digit = digit_extractor.Digit(keys[ITEM]);
// Get sub-counter
unsigned int sub_counter = digit >> LOG_COUNTER_LANES;
// Get counter lane
unsigned int counter_lane = digit & (COUNTER_LANES - 1);
if (IS_DESCENDING)
{
sub_counter = PACKING_RATIO - 1 - sub_counter;
counter_lane = COUNTER_LANES - 1 - counter_lane;
}
// Pointer to smem digit counter
digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane * BLOCK_THREADS * PACKING_RATIO + linear_tid * PACKING_RATIO + sub_counter];
// Load thread-exclusive prefix
thread_prefixes[ITEM] = *digit_counters[ITEM];
// Store inclusive prefix
*digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
}
::rocprim::syncthreads();
// Scan shared memory counters
ScanCounters();
::rocprim::syncthreads();
// Extract the local ranks of each key
#pragma unroll
for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
{
// Add in thread block exclusive prefix
ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
}
}
/**
* \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
*/
template <
typename UnsignedBits,
int KEYS_PER_THREAD,
typename DigitExtractorT>
HIPCUB_DEVICE inline void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter)
DigitExtractorT digit_extractor, ///< [in] The digit extractor
int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
{
// Rank keys
RankKeys(keys, ranks, digit_extractor);
// Get the inclusive and exclusive digit totals corresponding to the calling thread.
#pragma unroll
for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
{
int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
{
if (IS_DESCENDING)
bin_idx = RADIX_DIGITS - bin_idx - 1;
// Obtain ex/inclusive digit counts. (Unfortunately these all reside in the
// first counter column, resulting in unavoidable bank conflicts.)
unsigned int counter_lane = (bin_idx & (COUNTER_LANES - 1));
unsigned int sub_counter = bin_idx >> (LOG_COUNTER_LANES);
exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane * BLOCK_THREADS * PACKING_RATIO + sub_counter];
}
}
}
};
/**
* Radix-rank using match.any
*/
template <
int BLOCK_DIM_X,
int RADIX_BITS,
bool IS_DESCENDING,
BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH>
class BlockRadixRankMatch
{
private:
/******************************************************************************
* Type definitions and constants
******************************************************************************/
typedef int32_t RankT;
typedef int32_t DigitCounterT;
enum
{
// The thread block size in threads
BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
RADIX_DIGITS = 1 << RADIX_BITS,
LOG_WARP_THREADS = Log2<HIPCUB_DEVICE_WARP_THREADS>::VALUE, // log2 of the hardware wavefront size
WARP_THREADS = 1 << LOG_WARP_THREADS,
WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
PADDED_WARPS = ((WARPS & 0x1) == 0) ?
WARPS + 1 :
WARPS,
COUNTERS = PADDED_WARPS * RADIX_DIGITS,
RAKING_SEGMENT = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
PADDED_RAKING_SEGMENT = ((RAKING_SEGMENT & 0x1) == 0) ?
RAKING_SEGMENT + 1 :
RAKING_SEGMENT,
};
public:
enum
{
/// Number of bin-starting offsets tracked per thread
BINS_TRACKED_PER_THREAD = rocprim::maximum<int>()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
};
private:
/// BlockScan type
typedef BlockScan<
DigitCounterT,
BLOCK_THREADS,
INNER_SCAN_ALGORITHM,
BLOCK_DIM_Y,
BLOCK_DIM_Z,
ARCH>
BlockScanT;
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
/// Shared memory storage layout type for BlockRadixRank
struct __align__(16) _TempStorage
{
typename BlockScanT::TempStorage block_scan;
union __align__(16) Aliasable
{
volatile DigitCounterT warp_digit_counters[RADIX_DIGITS * PADDED_WARPS];
DigitCounterT raking_grid[BLOCK_THREADS * PADDED_RAKING_SEGMENT];
} aliasable;
};
#endif
/******************************************************************************
* Thread fields
******************************************************************************/
/// Shared storage reference
_TempStorage &temp_storage;
/// Linear thread-id
unsigned int linear_tid;
public:
/// \smemstorage{BlockRadixRankMatch}
struct TempStorage : Uninitialized<_TempStorage> {};
/******************************************************************//**
* \name Collective constructors
*********************************************************************/
//@{
/**
* \brief Collective constructor using the specified memory allocation as temporary storage.
*/
HIPCUB_DEVICE inline BlockRadixRankMatch(
TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage
:
temp_storage(temp_storage.Alias()),
linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
{}
//@} end member group
/******************************************************************//**
* \name Raking
*********************************************************************/
//@{
/**
* \brief Rank keys.
*/
template <
typename UnsignedBits,
int KEYS_PER_THREAD,
typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile
DigitExtractorT digit_extractor) ///< [in] The digit extractor
{
// Initialize shared digit counters
#pragma unroll
for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = 0;
::rocprim::syncthreads();
// Each warp will strip-mine its section of input, one strip at a time
volatile DigitCounterT *digit_counters[KEYS_PER_THREAD];
uint32_t warp_id = linear_tid >> LOG_WARP_THREADS;
uint32_t lane_mask_lt = LaneMaskLt();
#pragma unroll
for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
{
// My digit
uint32_t digit = digit_extractor.Digit(keys[ITEM]);
if (IS_DESCENDING)
digit = RADIX_DIGITS - digit - 1;
// Mask of peers who have same digit as me
uint32_t peer_mask = rocprim::MatchAny<RADIX_BITS>(digit);
// Pointer to smem digit counter for this key
digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit * PADDED_WARPS + warp_id];
// Number of occurrences in previous strips
DigitCounterT warp_digit_prefix = *digit_counters[ITEM];
// Warp-sync
WARP_SYNC(0xFFFFFFFF);
// Number of peers having same digit as me
int32_t digit_count = __popc(peer_mask);
// Number of lower-ranked peers having same digit seen so far
int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);
if (peer_digit_prefix == 0)
{
// First thread for each digit updates the shared warp counter
*digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
}
// Warp-sync
WARP_SYNC(0xFFFFFFFF);
// Number of prior keys having same digit
ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
}
::rocprim::syncthreads();
// Scan warp counters
DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];
#pragma unroll
for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM];
BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);
#pragma unroll
for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = scan_counters[ITEM];
::rocprim::syncthreads();
// Seed ranks with counter values from previous warps
#pragma unroll
for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
ranks[ITEM] += *digit_counters[ITEM];
}
/**
* \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
*/
template <
typename UnsignedBits,
int KEYS_PER_THREAD,
typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter)
DigitExtractorT digit_extractor, ///< [in] The digit extractor
int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
{
RankKeys(keys, ranks, digit_extractor);
// Get exclusive count for each digit
#pragma unroll
for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
{
int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
{
if (IS_DESCENDING)
bin_idx = RADIX_DIGITS - bin_idx - 1;
exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx * PADDED_WARPS];
}
}
}
};
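/**
 * Usage sketch (illustrative): ranking keys by their low 4 bits.
 * `LowBitsExtractor`, the kernel name and the tile shape are placeholders;
 * the library's digit extractors live in radix_rank_sort_operations.hpp.
 *
 * \code
 * struct LowBitsExtractor
 * {
 *     HIPCUB_DEVICE unsigned int Digit(unsigned int key) const
 *     {
 *         return key & 0xF;   // digit = low 4 bits
 *     }
 * };
 *
 * __global__ void RankKernel(...)
 * {
 *     using BlockRadixRankT = BlockRadixRank<128, 4, false>;
 *     __shared__ typename BlockRadixRankT::TempStorage temp_storage;
 *
 *     unsigned int keys[2];
 *     int ranks[2];
 *     ...
 *     BlockRadixRankT(temp_storage).RankKeys(keys, ranks, LowBitsExtractor());
 * }
 * \endcode
 */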
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
#include "../config.hpp"
#include "../util_type.cuh"
#include <cub/rocprim/functional.hpp>
#include <cub/rocprim/block/block_radix_sort.hpp>
#include "block_scan.cuh"
BEGIN_HIPCUB_NAMESPACE
template<
typename KeyT,
int BLOCK_DIM_X,
int ITEMS_PER_THREAD,
typename ValueT = NullType,
int RADIX_BITS = 4, /* ignored */
bool MEMOIZE_OUTER_SCAN = true, /* ignored */
BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS, /* ignored */
cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte, /* ignored */
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int PTX_ARCH = HIPCUB_ARCH /* ignored */
>
class BlockRadixSort
: private ::rocprim::block_radix_sort<
KeyT,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
ValueT,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>
{
static_assert(
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
"BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
);
using base_type =
typename ::rocprim::block_radix_sort<
KeyT,
BLOCK_DIM_X,
ITEMS_PER_THREAD,
ValueT,
BLOCK_DIM_Y,
BLOCK_DIM_Z
>;
// Reference to temporary storage (usually shared memory)
typename base_type::storage_type& temp_storage_;
public:
using TempStorage = typename base_type::storage_type;
HIPCUB_DEVICE inline
BlockRadixSort() : temp_storage_(private_storage())
{
}
HIPCUB_DEVICE inline
BlockRadixSort(TempStorage& temp_storage) : temp_storage_(temp_storage)
{
}
HIPCUB_DEVICE inline
void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort(keys, temp_storage_, begin_bit, end_bit);
}
HIPCUB_DEVICE inline
void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&values)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort(keys, values, temp_storage_, begin_bit, end_bit);
}
HIPCUB_DEVICE inline
void SortDescending(KeyT (&keys)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort_desc(keys, temp_storage_, begin_bit, end_bit);
}
HIPCUB_DEVICE inline
void SortDescending(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&values)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort_desc(keys, values, temp_storage_, begin_bit, end_bit);
}
HIPCUB_DEVICE inline
void SortBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort_to_striped(keys, temp_storage_, begin_bit, end_bit);
}
HIPCUB_DEVICE inline
void SortBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&values)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort_to_striped(keys, values, temp_storage_, begin_bit, end_bit);
}
HIPCUB_DEVICE inline
void SortDescendingBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort_desc_to_striped(keys, temp_storage_, begin_bit, end_bit);
}
HIPCUB_DEVICE inline
void SortDescendingBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
ValueT (&values)[ITEMS_PER_THREAD],
int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8)
{
base_type::sort_desc_to_striped(keys, values, temp_storage_, begin_bit, end_bit);
}
private:
HIPCUB_DEVICE inline
TempStorage& private_storage()
{
HIPCUB_SHARED_MEMORY TempStorage private_storage;
return private_storage;
}
};
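/**
 * Usage sketch (illustrative; kernel name and tile shape are placeholders):
 * a full 32-bit ascending sort, then a sort restricted to a bit range.
 *
 * \code
 * __global__ void SortKernel(...)
 * {
 *     using BlockRadixSortT = BlockRadixSort<unsigned int, 128, 4>;
 *     __shared__ typename BlockRadixSortT::TempStorage temp_storage;
 *
 *     unsigned int thread_keys[4];
 *     ...
 *     // Sort on all 32 bits of the key
 *     BlockRadixSortT(temp_storage).Sort(thread_keys);
 *
 *     // Or restrict the sort to bits [8, 24) of the key
 *     BlockRadixSortT(temp_storage).Sort(thread_keys, 8, 24);
 * }
 * \endcode
 */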
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* cub::BlockRakingLayout provides a conflict-free shared memory layout abstraction for warp-raking across thread block data.
*/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RAKING_LAYOUT_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_RAKING_LAYOUT_HPP_
#include <type_traits>
#include "../config.hpp"
#include <cub/rocprim/config.hpp>
#include <cub/rocprim/detail/various.hpp>
BEGIN_HIPCUB_NAMESPACE
/**
* \brief BlockRakingLayout provides a conflict-free shared memory layout abstraction for 1D raking across thread block data.
* \ingroup BlockModule
*
* \par Overview
* This type facilitates a shared memory usage pattern where a block of CUDA
* threads places elements into shared memory and then reduces the active
* parallelism to one "raking" warp of threads for serially aggregating consecutive
* sequences of shared items. Padding is inserted to eliminate bank conflicts
* (for most data types).
*
* \tparam T The data type to be exchanged.
* \tparam BLOCK_THREADS The thread block size in threads.
 * \tparam ARCH <b>[optional]</b> Target architecture identifier (unused; the rocPRIM back-end ignores it)
*/
template <
typename T,
int BLOCK_THREADS,
int ARCH = HIPCUB_ARCH /* ignored */
>
struct BlockRakingLayout
{
//---------------------------------------------------------------------
// Constants and type definitions
//---------------------------------------------------------------------
enum
{
/// The total number of elements that need to be cooperatively reduced
SHARED_ELEMENTS = BLOCK_THREADS,
/// Maximum number of warp-synchronous raking threads
MAX_RAKING_THREADS = ::rocprim::detail::get_min_warp_size(BLOCK_THREADS, HIPCUB_DEVICE_WARP_THREADS),
/// Number of raking elements per warp-synchronous raking thread (rounded up)
SEGMENT_LENGTH = (SHARED_ELEMENTS + MAX_RAKING_THREADS - 1) / MAX_RAKING_THREADS,
/// Never use a raking thread that will have no valid data (e.g., when BLOCK_THREADS is 62 and SEGMENT_LENGTH is 2, we should only use 31 raking threads)
RAKING_THREADS = (SHARED_ELEMENTS + SEGMENT_LENGTH - 1) / SEGMENT_LENGTH,
/// Pad each segment length with one element if segment length is not relatively prime to warp size and can't be optimized as a vector load
USE_SEGMENT_PADDING = ((SEGMENT_LENGTH & 1) == 0) && (SEGMENT_LENGTH > 2),
/// Total number of elements in the raking grid
GRID_ELEMENTS = RAKING_THREADS * (SEGMENT_LENGTH + USE_SEGMENT_PADDING),
/// Whether raking can skip bounds checking (true when the number of reduction elements is an exact multiple of the number of raking threads)
UNGUARDED = (SHARED_ELEMENTS % RAKING_THREADS == 0),
};
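// Worked example of the constants above (an editor's note, assuming a
// 64-lane wavefront): for T = int and BLOCK_THREADS = 128,
//   SHARED_ELEMENTS     = 128
//   MAX_RAKING_THREADS  = min(128, 64)   = 64
//   SEGMENT_LENGTH      = ceil(128 / 64) = 2
//   RAKING_THREADS      = ceil(128 / 2)  = 64
//   USE_SEGMENT_PADDING = 0   (segment length is even but not > 2)
//   GRID_ELEMENTS       = 64 * (2 + 0)   = 128
//   UNGUARDED           = 1   (128 is an exact multiple of 64)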
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
/**
* \brief Shared memory storage type
*/
struct __align__(16) _TempStorage
{
T buff[BlockRakingLayout::GRID_ELEMENTS];
};
#endif
/// Alias wrapper allowing storage to be unioned
struct TempStorage : Uninitialized<_TempStorage> {};
/**
* \brief Returns the location for the calling thread to place data into the grid
*/
static HIPCUB_DEVICE inline T* PlacementPtr(
TempStorage &temp_storage,
unsigned int linear_tid)
{
// Offset for partial
unsigned int offset = linear_tid;
// Add in one padding element for every segment
if (USE_SEGMENT_PADDING > 0)
{
offset += offset / SEGMENT_LENGTH;
}
// The offset now skips over one padding partial per full shared memory segment
return temp_storage.Alias().buff + offset;
}
/**
* \brief Returns the location for the calling thread to begin sequential raking
*/
static HIPCUB_DEVICE inline T* RakingPtr(
TempStorage &temp_storage,
unsigned int linear_tid)
{
return temp_storage.Alias().buff + (linear_tid * (SEGMENT_LENGTH + USE_SEGMENT_PADDING));
}
};
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RAKING_LAYOUT_HPP_
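// ----------------------------------------------------------------------------
// Illustrative sketch (an editor's addition; the kernel name and the
// <hipcub/hipcub.hpp> umbrella include are assumptions): the placement and
// raking phases of a block-wide sum built on the layout above. A complete
// reduction would finish with a warp-synchronous step across the raking
// threads; this sketch stops at the per-segment partial sums.
#include <hipcub/hipcub.hpp>

template<int BLOCK_THREADS>
__global__ void example_raking_partial_sums(const int* d_in, int* d_seg_sums)
{
    using Layout = hipcub::BlockRakingLayout<int, BLOCK_THREADS>;
    __shared__ typename Layout::TempStorage temp_storage;
    // Phase 1: every thread deposits its partial into the padded grid.
    *Layout::PlacementPtr(temp_storage, threadIdx.x) = d_in[threadIdx.x];
    __syncthreads();
    // Phase 2: only the raking threads stay active; each serially sums its
    // segment of consecutive partials.
    if (threadIdx.x < (unsigned int)Layout::RAKING_THREADS)
    {
        int* segment = Layout::RakingPtr(temp_storage, threadIdx.x);
        int sum = segment[0];
        for (int i = 1; i < Layout::SEGMENT_LENGTH; ++i)
        {
            // Guard ragged trailing segments when UNGUARDED is false.
            if (Layout::UNGUARDED ||
                (int)(threadIdx.x * Layout::SEGMENT_LENGTH + i) < Layout::SHARED_ELEMENTS)
                sum += segment[i];
        }
        d_seg_sums[threadIdx.x] = sum;
    }
}
// ----------------------------------------------------------------------------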
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_
#define HIPCUB_ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_
#include <type_traits>
#include <cub/rocprim/block/block_reduce.hpp>
BEGIN_HIPCUB_NAMESPACE
namespace detail
{
inline constexpr
typename std::underlying_type<::rocprim::block_reduce_algorithm>::type
to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm v)
{
using utype = std::underlying_type<::rocprim::block_reduce_algorithm>::type;
return static_cast<utype>(v);
}
}
enum BlockReduceAlgorithm
{
BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
= detail::to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm::raking_reduce_commutative_only),
BLOCK_REDUCE_RAKING
= detail::to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm::raking_reduce),
BLOCK_REDUCE_WARP_REDUCTIONS
= detail::to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm::using_warp_reduce)
};
template<
typename T,
int BLOCK_DIM_X,
BlockReduceAlgorithm ALGORITHM = BLOCK_REDUCE_WARP_REDUCTIONS,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int ARCH = HIPCUB_ARCH /* ignored */
>
class BlockReduce
: private ::rocprim::block_reduce<
T,
BLOCK_DIM_X,
static_cast<::rocprim::block_reduce_algorithm>(ALGORITHM),
BLOCK_DIM_Y,
BLOCK_DIM_Z
>
{
static_assert(
BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
"BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
);
using base_type =
typename ::rocprim::block_reduce<
T,
BLOCK_DIM_X,
static_cast<::rocprim::block_reduce_algorithm>(ALGORITHM),
BLOCK_DIM_Y,
BLOCK_DIM_Z
>;
// Reference to temporary storage (usually shared memory)
typename base_type::storage_type& temp_storage_;
public:
using TempStorage = typename base_type::storage_type;
HIPCUB_DEVICE inline
BlockReduce() : temp_storage_(private_storage())
{
}
HIPCUB_DEVICE inline
BlockReduce(TempStorage& temp_storage) : temp_storage_(temp_storage)
{
}
HIPCUB_DEVICE inline
T Sum(T input)
{
base_type::reduce(input, input, temp_storage_);
return input;
}
HIPCUB_DEVICE inline
T Sum(T input, int valid_items)
{
base_type::reduce(input, input, valid_items, temp_storage_);
return input;
}
template<int ITEMS_PER_THREAD>
HIPCUB_DEVICE inline
T Sum(T(&input)[ITEMS_PER_THREAD])
{
T output;
base_type::reduce(input, output, temp_storage_);
return output;
}
template<typename ReduceOp>
HIPCUB_DEVICE inline
T Reduce(T input, ReduceOp reduce_op)
{
base_type::reduce(input, input, temp_storage_, reduce_op);
return input;
}
template<typename ReduceOp>
HIPCUB_DEVICE inline
T Reduce(T input, ReduceOp reduce_op, int valid_items)
{
base_type::reduce(input, input, valid_items, temp_storage_, reduce_op);
return input;
}
template<int ITEMS_PER_THREAD, typename ReduceOp>
HIPCUB_DEVICE inline
T Reduce(T(&input)[ITEMS_PER_THREAD], ReduceOp reduce_op)
{
T output;
base_type::reduce(input, output, temp_storage_, reduce_op);
return output;
}
private:
HIPCUB_DEVICE inline
TempStorage& private_storage()
{
HIPCUB_SHARED_MEMORY TempStorage private_storage;
return private_storage;
}
};
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_
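// ----------------------------------------------------------------------------
// Illustrative usage sketch (an editor's addition; the kernel name and the
// <hipcub/hipcub.hpp> umbrella include are assumptions): a block-wide sum in
// which, per the CUB convention, only thread 0 receives the valid aggregate.
// Reduce(value, op) and the valid_items overloads follow the same pattern.
#include <hipcub/hipcub.hpp>

__global__ void example_block_reduce_sum(const int* d_in, int* d_block_sums)
{
    using BlockReduceT = hipcub::BlockReduce<int, 256>;
    __shared__ typename BlockReduceT::TempStorage temp_storage;
    int thread_value = d_in[blockIdx.x * blockDim.x + threadIdx.x];
    // Collective sum across the 256-thread block.
    int block_sum = BlockReduceT(temp_storage).Sum(thread_value);
    if (threadIdx.x == 0)
        d_block_sums[blockIdx.x] = block_sum;
}
// ----------------------------------------------------------------------------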