Unverified Commit 880b3b1f authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Fix] Enable lint check for cuh files and fix compiler warnings (#4585)

* disable warning for tensorpipe

* fix warning

* enable lint check for cuh files

* resolve comments
parent 166b273b
...@@ -4,21 +4,20 @@ ...@@ -4,21 +4,20 @@
* \brief frequency hashmap - used to select top-k frequency edges of each node * \brief frequency hashmap - used to select top-k frequency edges of each node
*/ */
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_CUH_ #ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_CUH_ #define DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <tuple>
namespace dgl { namespace dgl {
namespace sampling { namespace sampling {
namespace impl { namespace impl {
template<typename IdxType> template<typename IdxType>
class DeviceEdgeHashmap { class DeviceEdgeHashmap {
public: public:
struct EdgeItem { struct EdgeItem {
IdxType src; IdxType src;
IdxType cnt; IdxType cnt;
...@@ -27,13 +26,13 @@ public: ...@@ -27,13 +26,13 @@ public:
DeviceEdgeHashmap(int64_t num_dst, int64_t num_items_each_dst, DeviceEdgeHashmap(int64_t num_dst, int64_t num_items_each_dst,
IdxType* dst_unique_edges, EdgeItem *edge_hashmap): IdxType* dst_unique_edges, EdgeItem *edge_hashmap):
_num_dst(num_dst), _num_items_each_dst(num_items_each_dst), _num_dst(num_dst), _num_items_each_dst(num_items_each_dst),
_dst_unique_edges(dst_unique_edges), _edge_hashmap(edge_hashmap) {}; _dst_unique_edges(dst_unique_edges), _edge_hashmap(edge_hashmap) {}
// return the old cnt of this edge // return the old cnt of this edge
inline __device__ IdxType InsertEdge(const IdxType &src, const IdxType &dst_idx); inline __device__ IdxType InsertEdge(const IdxType &src, const IdxType &dst_idx);
inline __device__ IdxType GetDstCount(const IdxType &dst_idx); inline __device__ IdxType GetDstCount(const IdxType &dst_idx);
inline __device__ IdxType GetEdgeCount(const IdxType &src, const IdxType &dst_idx); inline __device__ IdxType GetEdgeCount(const IdxType &src, const IdxType &dst_idx);
private: private:
int64_t _num_dst; int64_t _num_dst;
int64_t _num_items_each_dst; int64_t _num_items_each_dst;
IdxType *_dst_unique_edges; IdxType *_dst_unique_edges;
...@@ -41,12 +40,12 @@ private: ...@@ -41,12 +40,12 @@ private:
inline __device__ IdxType EdgeHash(const IdxType &id) const { inline __device__ IdxType EdgeHash(const IdxType &id) const {
return id % _num_items_each_dst; return id % _num_items_each_dst;
}; }
}; };
template<typename IdxType> template<typename IdxType>
class FrequencyHashmap { class FrequencyHashmap {
public: public:
static constexpr int64_t kDefaultEdgeTableScale = 3; static constexpr int64_t kDefaultEdgeTableScale = 3;
FrequencyHashmap() = delete; FrequencyHashmap() = delete;
FrequencyHashmap(int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx, cudaStream_t stream, FrequencyHashmap(int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx, cudaStream_t stream,
...@@ -57,7 +56,7 @@ public: ...@@ -57,7 +56,7 @@ public:
const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype, const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
const int64_t num_edges, const int64_t num_edges_per_node, const int64_t num_edges, const int64_t num_edges_per_node,
const int64_t num_pick); const int64_t num_pick);
private: private:
DGLContext _ctx; DGLContext _ctx;
cudaStream_t _stream; cudaStream_t _stream;
DeviceEdgeHashmap<IdxType> *_device_edge_hashmap; DeviceEdgeHashmap<IdxType> *_device_edge_hashmap;
...@@ -66,9 +65,7 @@ private: ...@@ -66,9 +65,7 @@ private:
}; };
}; // namespace impl }; // namespace impl
}; // namespace sampling }; // namespace sampling
}; // namespace dgl }; // namespace dgl
#endif // DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_CUH_ #endif // DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
...@@ -69,7 +69,7 @@ CompactGraphsCPU( ...@@ -69,7 +69,7 @@ CompactGraphsCPU(
} }
// Reserve the space for hash maps before ahead to aoivd rehashing // Reserve the space for hash maps before ahead to aoivd rehashing
for (size_t i = 0; i < num_ntypes; ++i) { for (size_t i = 0; i < static_cast<size_t>(num_ntypes); ++i) {
if (i < always_preserve.size()) if (i < always_preserve.size())
hashmaps[i].Reserve(always_preserve[i]->shape[0] + max_vertex_cnt[i]); hashmaps[i].Reserve(always_preserve[i]->shape[0] + max_vertex_cnt[i]);
else else
......
...@@ -402,7 +402,7 @@ void NNDescent(const NDArray& points, const IdArray& offsets, ...@@ -402,7 +402,7 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
// randomly select neighbors as candidates // randomly select neighbors as candidates
int num_threads = omp_get_max_threads(); int num_threads = omp_get_max_threads();
runtime::parallel_for(0, num_threads, [&](size_t b, size_t e) { runtime::parallel_for(0, num_threads, [&](IdType b, IdType e) {
for (auto tid = b; tid < e; ++tid) { for (auto tid = b; tid < e; ++tid) {
for (IdType i = point_idx_start; i < point_idx_end; ++i) { for (IdType i = point_idx_start; i < point_idx_end; ++i) {
IdType local_idx = i - point_idx_start; IdType local_idx = i - point_idx_start;
......
...@@ -22,6 +22,11 @@ ...@@ -22,6 +22,11 @@
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <algorithm>
#include <memory>
#include <tuple>
#include <vector>
#include <utility>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../../../runtime/cuda/cuda_hashtable.cuh" #include "../../../runtime/cuda/cuda_hashtable.cuh"
...@@ -265,4 +270,4 @@ MapEdges( ...@@ -265,4 +270,4 @@ MapEdges(
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
#endif #endif // DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_
...@@ -68,7 +68,7 @@ static void StringAppendV(string* dst, const char* format, va_list ap) { ...@@ -68,7 +68,7 @@ static void StringAppendV(string* dst, const char* format, va_list ap) {
int result = vsnprintf(space, sizeof(space), format, backup_ap); int result = vsnprintf(space, sizeof(space), format, backup_ap);
va_end(backup_ap); va_end(backup_ap);
if ((result >= 0) && (result < sizeof(space))) { if ((result >= 0) && (result < static_cast<int>(sizeof(space)))) {
// It fit // It fit
dst->append(space, result); dst->append(space, result);
return; return;
......
...@@ -320,15 +320,6 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage") ...@@ -320,15 +320,6 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
*rv = rst; *rv = rst;
}); });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t message_size = args[0];
std::shared_ptr<RPCMessage> rst(new RPCMessage);
*rv = rst;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage") DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<RPCMessage> rst(new RPCMessage); std::shared_ptr<RPCMessage> rst(new RPCMessage);
...@@ -482,7 +473,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") ...@@ -482,7 +473,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
dgl_id_t idx = 0; dgl_id_t idx = 0;
for (dgl_id_t i = 0; i < ID_size; ++i) { for (dgl_id_t i = 0; i < ID_size; ++i) {
dgl_id_t p_id = part_id_data[i]; dgl_id_t p_id = part_id_data[i];
if (p_id == local_machine_id) { if (static_cast<int>(p_id) == local_machine_id) {
dgl_id_t l_id = local_id_data[idx++]; dgl_id_t l_id = local_id_data[idx++];
CHECK_LT(l_id, local_data_shape[0]); CHECK_LT(l_id, local_data_shape[0]);
CHECK_GE(l_id, 0); CHECK_GE(l_id, 0);
...@@ -497,7 +488,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") ...@@ -497,7 +488,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
} }
// Send remote id // Send remote id
int msg_count = 0; int msg_count = 0;
for (int i = 0; i < remote_ids.size(); ++i) { for (size_t i = 0; i < remote_ids.size(); ++i) {
if (remote_ids[i].size() != 0) { if (remote_ids[i].size() != 0) {
RPCMessage msg; RPCMessage msg;
msg.service_id = service_id; msg.service_id = service_id;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <deque> #include <deque>
#include <mutex> #include <mutex>
#include <chrono> #include <chrono>
#include <utility>
namespace dgl { namespace dgl {
namespace rpc { namespace rpc {
......
...@@ -64,128 +64,128 @@ class OrderedHashTable; ...@@ -64,128 +64,128 @@ class OrderedHashTable;
*/ */
template<typename IdType> template<typename IdType>
class DeviceOrderedHashTable { class DeviceOrderedHashTable {
public: public:
/**
* \brief An entry in the hashtable.
*/
struct Mapping {
/** /**
* \brief An entry in the hashtable. * \brief The ID of the item inserted.
*/
struct Mapping {
/**
* \brief The ID of the item inserted.
*/
IdType key;
/**
* \brief The index of the item in the unique list.
*/
IdType local;
/**
* \brief The index of the item when inserted into the hashtable (e.g.,
* the index within the array passed into FillWithDuplicates()).
*/
int64_t index;
};
typedef const Mapping* ConstIterator;
DeviceOrderedHashTable(
const DeviceOrderedHashTable& other) = default;
DeviceOrderedHashTable& operator=(
const DeviceOrderedHashTable& other) = default;
/**
* \brief Find the non-mutable mapping of a given key within the hash table.
*
* WARNING: The key must exist within the hashtable. Searching for a key not
* in the hashtable is undefined behavior.
*
* \param id The key to search for.
*
* \return An iterator to the mapping.
*/
inline __device__ ConstIterator Search(
const IdType id) const {
const IdType pos = SearchForPosition(id);
return &table_[pos];
}
/**
* \brief Check whether a key exists within the hashtable.
*
* \param id The key to check for.
*
* \return True if the key exists in the hashtable.
*/ */
inline __device__ bool Contains( IdType key;
const IdType id) const {
IdType pos = Hash(id);
IdType delta = 1;
while (table_[pos].key != kEmptyKey) {
if (table_[pos].key == id) {
return true;
}
pos = Hash(pos+delta);
delta +=1;
}
return false;
}
protected:
// Must be uniform bytes for memset to work
static constexpr IdType kEmptyKey = static_cast<IdType>(-1);
const Mapping * table_;
size_t size_;
/** /**
* \brief Create a new device-side handle to the hash table. * \brief The index of the item in the unique list.
* */
* \param table The table stored in GPU memory. IdType local;
* \param size The size of the table.
*/
explicit DeviceOrderedHashTable(
const Mapping * table,
size_t size);
/** /**
* \brief Search for an item in the hash table which is known to exist. * \brief The index of the item when inserted into the hashtable (e.g.,
* * the index within the array passed into FillWithDuplicates()).
* WARNING: If the ID searched for does not exist within the hashtable, this
* function will never return.
*
* \param id The ID of the item to search for.
*
* \return The the position of the item in the hashtable.
*/ */
inline __device__ IdType SearchForPosition( int64_t index;
const IdType id) const { };
IdType pos = Hash(id);
typedef const Mapping* ConstIterator;
// linearly scan for matching entry
IdType delta = 1; DeviceOrderedHashTable(
while (table_[pos].key != id) { const DeviceOrderedHashTable& other) = default;
assert(table_[pos].key != kEmptyKey); DeviceOrderedHashTable& operator=(
pos = Hash(pos+delta); const DeviceOrderedHashTable& other) = default;
delta +=1;
/**
* \brief Find the non-mutable mapping of a given key within the hash table.
*
* WARNING: The key must exist within the hashtable. Searching for a key not
* in the hashtable is undefined behavior.
*
* \param id The key to search for.
*
* \return An iterator to the mapping.
*/
inline __device__ ConstIterator Search(
const IdType id) const {
const IdType pos = SearchForPosition(id);
return &table_[pos];
}
/**
* \brief Check whether a key exists within the hashtable.
*
* \param id The key to check for.
*
* \return True if the key exists in the hashtable.
*/
inline __device__ bool Contains(
const IdType id) const {
IdType pos = Hash(id);
IdType delta = 1;
while (table_[pos].key != kEmptyKey) {
if (table_[pos].key == id) {
return true;
} }
assert(pos < size_); pos = Hash(pos+delta);
delta +=1;
return pos;
} }
return false;
/** }
* \brief Hash an ID to a to a position in the hash table.
* protected:
* \param id The ID to hash. // Must be uniform bytes for memset to work
* static constexpr IdType kEmptyKey = static_cast<IdType>(-1);
* \return The hash.
*/ const Mapping * table_;
inline __device__ size_t Hash( size_t size_;
const IdType id) const {
return id % size_; /**
* \brief Create a new device-side handle to the hash table.
*
* \param table The table stored in GPU memory.
* \param size The size of the table.
*/
explicit DeviceOrderedHashTable(
const Mapping * table,
size_t size);
/**
* \brief Search for an item in the hash table which is known to exist.
*
* WARNING: If the ID searched for does not exist within the hashtable, this
* function will never return.
*
* \param id The ID of the item to search for.
*
* \return The the position of the item in the hashtable.
*/
inline __device__ IdType SearchForPosition(
const IdType id) const {
IdType pos = Hash(id);
// linearly scan for matching entry
IdType delta = 1;
while (table_[pos].key != id) {
assert(table_[pos].key != kEmptyKey);
pos = Hash(pos+delta);
delta +=1;
} }
assert(pos < size_);
friend class OrderedHashTable<IdType>;
return pos;
}
/**
* \brief Hash an ID to a to a position in the hash table.
*
* \param id The ID to hash.
*
* \return The hash.
*/
inline __device__ size_t Hash(
const IdType id) const {
return id % size_;
}
friend class OrderedHashTable<IdType>;
}; };
/*! /*!
...@@ -221,84 +221,83 @@ class DeviceOrderedHashTable { ...@@ -221,84 +221,83 @@ class DeviceOrderedHashTable {
*/ */
template<typename IdType> template<typename IdType>
class OrderedHashTable { class OrderedHashTable {
public: public:
static constexpr int kDefaultScale = 3; static constexpr int kDefaultScale = 3;
using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping; using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;
/** /**
* \brief Create a new ordered hash table. The amoutn of GPU memory * \brief Create a new ordered hash table. The amoutn of GPU memory
* consumed by the resulting hashtable is O(`size` * 2^`scale`). * consumed by the resulting hashtable is O(`size` * 2^`scale`).
* *
* \param size The number of items to insert into the hashtable. * \param size The number of items to insert into the hashtable.
* \param ctx The device context to store the hashtable on. * \param ctx The device context to store the hashtable on.
* \param scale The power of two times larger the number of buckets should * \param scale The power of two times larger the number of buckets should
* be than the number of items. * be than the number of items.
* \param stream The stream to use for initializing the hashtable. * \param stream The stream to use for initializing the hashtable.
*/ */
OrderedHashTable( OrderedHashTable(
const size_t size, const size_t size,
DGLContext ctx, DGLContext ctx,
cudaStream_t stream, cudaStream_t stream,
const int scale = kDefaultScale); const int scale = kDefaultScale);
/** /**
* \brief Cleanup after the hashtable. * \brief Cleanup after the hashtable.
*/ */
~OrderedHashTable(); ~OrderedHashTable();
// Disable copying // Disable copying
OrderedHashTable( OrderedHashTable(
const OrderedHashTable& other) = delete; const OrderedHashTable& other) = delete;
OrderedHashTable& operator=( OrderedHashTable& operator=(
const OrderedHashTable& other) = delete; const OrderedHashTable& other) = delete;
/** /**
* \brief Fill the hashtable with the array containing possibly duplicate * \brief Fill the hashtable with the array containing possibly duplicate
* IDs. * IDs.
* *
* \param input The array of IDs to insert. * \param input The array of IDs to insert.
* \param num_input The number of IDs to insert. * \param num_input The number of IDs to insert.
* \param unique The list of unique IDs inserted. * \param unique The list of unique IDs inserted.
* \param num_unique The number of unique IDs inserted. * \param num_unique The number of unique IDs inserted.
* \param stream The stream to perform operations on. * \param stream The stream to perform operations on.
*/ */
void FillWithDuplicates( void FillWithDuplicates(
const IdType * const input, const IdType * const input,
const size_t num_input, const size_t num_input,
IdType * const unique, IdType * const unique,
int64_t * const num_unique, int64_t * const num_unique,
cudaStream_t stream); cudaStream_t stream);
/** /**
* \brief Fill the hashtable with an array of unique keys. * \brief Fill the hashtable with an array of unique keys.
* *
* \param input The array of unique IDs. * \param input The array of unique IDs.
* \param num_input The number of keys. * \param num_input The number of keys.
* \param stream The stream to perform operations on. * \param stream The stream to perform operations on.
*/ */
void FillWithUnique( void FillWithUnique(
const IdType * const input, const IdType * const input,
const size_t num_input, const size_t num_input,
cudaStream_t stream); cudaStream_t stream);
/** /**
* \brief Get a verison of the hashtable usable from device functions. * \brief Get a verison of the hashtable usable from device functions.
* *
* \return This hashtable. * \return This hashtable.
*/ */
DeviceOrderedHashTable<IdType> DeviceHandle() const; DeviceOrderedHashTable<IdType> DeviceHandle() const;
private: private:
Mapping * table_; Mapping * table_;
size_t size_; size_t size_;
DGLContext ctx_; DGLContext ctx_;
}; };
} // cuda } // namespace cuda
} // runtime } // namespace runtime
} // dgl } // namespace dgl
#endif #endif // DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_
#!/usr/bin/env python3
# pylint: disable=protected-access, unused-variable, locally-disabled, len-as-condition
"""Lint helper to generate lint summary of source.
Copyright by Contributors.
Borrowed from dmlc-core/scripts/lint.py@939c052
"""
from __future__ import print_function
import argparse
import codecs
import sys
import re
import os
import cpplint
from cpplint import _cpplint_state
from pylint import epylint
CXX_SUFFIX = set(['cc', 'c', 'cpp', 'h', 'cu', 'hpp', 'cuh'])
PYTHON_SUFFIX = set(['py'])
def filepath_enumerate(paths):
"""Enumerate the file paths of all subfiles of the list of paths"""
out = []
for path in paths:
if os.path.isfile(path):
out.append(path)
else:
for root, dirs, files in os.walk(path):
for name in files:
out.append(os.path.normpath(os.path.join(root, name)))
return out
# pylint: disable=useless-object-inheritance
class LintHelper(object):
"""Class to help runing the lint and records summary"""
@staticmethod
def _print_summary_map(strm, result_map, ftype):
"""Print summary of certain result map."""
if len(result_map) == 0:
return 0
npass = sum(1 for x in result_map.values() if len(x) == 0)
strm.write(f'====={npass}/{len(result_map)} {ftype} files passed check=====\n')
for fname, emap in result_map.items():
if len(emap) == 0:
continue
strm.write(
f'{fname}: {sum(emap.values())} Errors of {len(emap)} Categories map={str(emap)}\n')
return len(result_map) - npass
def __init__(self):
self.project_name = None
self.cpp_header_map = {}
self.cpp_src_map = {}
self.python_map = {}
pylint_disable = ['superfluous-parens',
'too-many-instance-attributes',
'too-few-public-methods']
# setup pylint
self.pylint_opts = ['--extension-pkg-whitelist=numpy',
'--disable=' + ','.join(pylint_disable)]
self.pylint_cats = set(['error', 'warning', 'convention', 'refactor'])
# setup cpp lint
cpplint_args = ['--quiet', '--extensions=' + (','.join(CXX_SUFFIX)), '.']
_ = cpplint.ParseArguments(cpplint_args)
cpplint._SetFilters(','.join(['-build/c++11',
'-build/namespaces',
'-build/include,',
'+build/include_what_you_use',
'+build/include_order']))
cpplint._SetCountingStyle('toplevel')
cpplint._line_length = 100
def process_cpp(self, path, suffix):
"""Process a cpp file."""
_cpplint_state.ResetErrorCounts()
cpplint.ProcessFile(str(path), _cpplint_state.verbose_level)
_cpplint_state.PrintErrorCounts()
errors = _cpplint_state.errors_by_category.copy()
if suffix == 'h':
self.cpp_header_map[str(path)] = errors
else:
self.cpp_src_map[str(path)] = errors
def process_python(self, path):
"""Process a python file."""
(pylint_stdout, pylint_stderr) = epylint.py_run(
' '.join([str(path)] + self.pylint_opts), return_std=True)
emap = {}
err = pylint_stderr.read()
if len(err):
print(err)
for line in pylint_stdout:
sys.stderr.write(line)
key = line.split(':')[-1].split('(')[0].strip()
if key not in self.pylint_cats:
continue
if key not in emap:
emap[key] = 1
else:
emap[key] += 1
self.python_map[str(path)] = emap
def print_summary(self, strm):
"""Print summary of lint."""
nerr = 0
nerr += LintHelper._print_summary_map(strm, self.cpp_header_map, 'cpp-header')
nerr += LintHelper._print_summary_map(strm, self.cpp_src_map, 'cpp-source')
nerr += LintHelper._print_summary_map(strm, self.python_map, 'python')
if nerr == 0:
strm.write('All passed!\n')
else:
strm.write(f'{nerr} files failed lint\n')
return nerr
# singleton helper for lint check
_HELPER = LintHelper()
def get_header_guard_dmlc(filename):
"""Get Header Guard Convention for DMLC Projects.
For headers in include, directly use the path
For headers in src, use project name plus path
Examples: with project-name = dmlc
include/dmlc/timer.h -> DMLC_TIMTER_H_
src/io/libsvm_parser.h -> DMLC_IO_LIBSVM_PARSER_H_
"""
fileinfo = cpplint.FileInfo(filename)
file_path_from_root = fileinfo.RepositoryName()
inc_list = ['include', 'api', 'wrapper', 'contrib']
if os.name == 'nt':
inc_list.append("mshadow")
if file_path_from_root.find('src/') != -1 and _HELPER.project_name is not None:
idx = file_path_from_root.find('src/')
file_path_from_root = _HELPER.project_name + file_path_from_root[idx + 3:]
else:
idx = file_path_from_root.find("include/")
if idx != -1:
file_path_from_root = file_path_from_root[idx + 8:]
for spath in inc_list:
prefix = spath + '/'
if file_path_from_root.startswith(prefix):
file_path_from_root = re.sub('^' + prefix, '', file_path_from_root)
break
return re.sub(r'[-./\s]', '_', file_path_from_root).upper() + '_'
cpplint.GetHeaderGuardCPPVariable = get_header_guard_dmlc
def process(fname, allow_type):
"""Process a file."""
fname = str(fname)
arr = fname.rsplit('.', 1)
if fname.find('#') != -1 or arr[-1] not in allow_type:
return
if arr[-1] in CXX_SUFFIX:
_HELPER.process_cpp(fname, arr[-1])
if arr[-1] in PYTHON_SUFFIX:
_HELPER.process_python(fname)
def main():
"""Main entry function."""
parser = argparse.ArgumentParser(description="lint source codes")
parser.add_argument('project', help='project name')
parser.add_argument('filetype', choices=['python', 'cpp', 'all'],
help='source code type')
parser.add_argument('path', nargs='+', help='path to traverse')
parser.add_argument('--exclude_path', nargs='+', default=[],
help='exclude this path, and all subfolders if path is a folder')
parser.add_argument('--quiet', action='store_true', help='run cpplint in quiet mode')
parser.add_argument('--pylint-rc', default=None,
help='pylint rc file')
args = parser.parse_args()
_HELPER.project_name = args.project
if args.pylint_rc is not None:
_HELPER.pylint_opts = ['--rcfile='+args.pylint_rc,]
file_type = args.filetype
allow_type = []
if file_type in ('python', 'all'):
allow_type += PYTHON_SUFFIX
if file_type in ('cpp', 'all'):
allow_type += CXX_SUFFIX
allow_type = set(allow_type)
if sys.version_info.major == 2 and os.name != 'nt':
sys.stderr = codecs.StreamReaderWriter(sys.stderr,
codecs.getreader('utf8'),
codecs.getwriter('utf8'),
'replace')
# get excluded files
excluded_paths = filepath_enumerate(args.exclude_path)
for path in args.path:
if os.path.isfile(path):
normpath = os.path.normpath(path)
if normpath not in excluded_paths:
process(path, allow_type)
else:
for root, dirs, files in os.walk(path):
for name in files:
file_path = os.path.normpath(os.path.join(root, name))
if file_path not in excluded_paths:
process(file_path, allow_type)
nerr = _HELPER.print_summary(sys.stderr)
sys.exit(nerr > 0)
if __name__ == '__main__':
main()
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# cpplint # cpplint
echo 'Checking code style of C++ codes...' echo 'Checking code style of C++ codes...'
python3 third_party/dmlc-core/scripts/lint.py dgl cpp include src || exit 1 python3 tests/lint/lint.py dgl cpp include src || exit 1
# pylint # pylint
echo 'Checking code style of python codes...' echo 'Checking code style of python codes...'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment