Commit d0930731 authored by Abseil Team's avatar Abseil Team Committed by Mark Barolak
Browse files

Googletest export

Fix gmock_gen to use MOCK_METHOD instead of old style macros.  Fix several
related bugs in argument parsing and return types.
- handle commas more correctly in return types
- handle commas correctly in arguments
- handle default values more correctly

PiperOrigin-RevId: 294435093
parent 56de7cc8
...@@ -30,11 +30,11 @@ ...@@ -30,11 +30,11 @@
try: try:
# Python 3.x # Python 3.x
import builtins import builtins
except ImportError: except ImportError:
# Python 2.x # Python 2.x
import __builtin__ as builtins import __builtin__ as builtins
import sys import sys
import traceback import traceback
...@@ -45,15 +45,15 @@ from cpp import utils ...@@ -45,15 +45,15 @@ from cpp import utils
if not hasattr(builtins, 'reversed'): if not hasattr(builtins, 'reversed'):
# Support Python 2.3 and earlier. # Support Python 2.3 and earlier.
def reversed(seq): def reversed(seq):
for i in range(len(seq)-1, -1, -1): for i in range(len(seq)-1, -1, -1):
yield seq[i] yield seq[i]
if not hasattr(builtins, 'next'): if not hasattr(builtins, 'next'):
# Support Python 2.5 and earlier. # Support Python 2.5 and earlier.
def next(obj): def next(obj):
return obj.next() return obj.next()
VISIBILITY_PUBLIC, VISIBILITY_PROTECTED, VISIBILITY_PRIVATE = range(3) VISIBILITY_PUBLIC, VISIBILITY_PROTECTED, VISIBILITY_PRIVATE = range(3)
...@@ -98,1598 +98,1610 @@ _NAMESPACE_POP = 'ns-pop' ...@@ -98,1598 +98,1610 @@ _NAMESPACE_POP = 'ns-pop'
# TODO(nnorwitz): use this as a singleton for templated_types, etc # TODO(nnorwitz): use this as a singleton for templated_types, etc
# where we don't want to create a new empty dict each time. It is also const. # where we don't want to create a new empty dict each time. It is also const.
class _NullDict(object): class _NullDict(object):
__contains__ = lambda self: False __contains__ = lambda self: False
keys = values = items = iterkeys = itervalues = iteritems = lambda self: () keys = values = items = iterkeys = itervalues = iteritems = lambda self: ()
# TODO(nnorwitz): move AST nodes into a separate module. # TODO(nnorwitz): move AST nodes into a separate module.
class Node(object): class Node(object):
"""Base AST node.""" """Base AST node."""
def __init__(self, start, end): def __init__(self, start, end):
self.start = start self.start = start
self.end = end self.end = end
def IsDeclaration(self): def IsDeclaration(self):
"""Returns bool if this node is a declaration.""" """Returns bool if this node is a declaration."""
return False return False
def IsDefinition(self): def IsDefinition(self):
"""Returns bool if this node is a definition.""" """Returns bool if this node is a definition."""
return False return False
def IsExportable(self): def IsExportable(self):
"""Returns bool if this node exportable from a header file.""" """Returns bool if this node exportable from a header file."""
return False return False
def Requires(self, node): def Requires(self, node):
"""Does this AST node require the definition of the node passed in?""" """Does this AST node require the definition of the node passed in?"""
return False return False
def XXX__str__(self): def XXX__str__(self):
return self._StringHelper(self.__class__.__name__, '') return self._StringHelper(self.__class__.__name__, '')
def _StringHelper(self, name, suffix): def _StringHelper(self, name, suffix):
if not utils.DEBUG: if not utils.DEBUG:
return '%s(%s)' % (name, suffix) return '%s(%s)' % (name, suffix)
return '%s(%d, %d, %s)' % (name, self.start, self.end, suffix) return '%s(%d, %d, %s)' % (name, self.start, self.end, suffix)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
class Define(Node): class Define(Node):
def __init__(self, start, end, name, definition): def __init__(self, start, end, name, definition):
Node.__init__(self, start, end) Node.__init__(self, start, end)
self.name = name self.name = name
self.definition = definition self.definition = definition
def __str__(self): def __str__(self):
value = '%s %s' % (self.name, self.definition) value = '%s %s' % (self.name, self.definition)
return self._StringHelper(self.__class__.__name__, value) return self._StringHelper(self.__class__.__name__, value)
class Include(Node): class Include(Node):
def __init__(self, start, end, filename, system): def __init__(self, start, end, filename, system):
Node.__init__(self, start, end) Node.__init__(self, start, end)
self.filename = filename self.filename = filename
self.system = system self.system = system
def __str__(self): def __str__(self):
fmt = '"%s"' fmt = '"%s"'
if self.system: if self.system:
fmt = '<%s>' fmt = '<%s>'
return self._StringHelper(self.__class__.__name__, fmt % self.filename) return self._StringHelper(self.__class__.__name__, fmt % self.filename)
class Goto(Node): class Goto(Node):
def __init__(self, start, end, label): def __init__(self, start, end, label):
Node.__init__(self, start, end) Node.__init__(self, start, end)
self.label = label self.label = label
def __str__(self): def __str__(self):
return self._StringHelper(self.__class__.__name__, str(self.label)) return self._StringHelper(self.__class__.__name__, str(self.label))
class Expr(Node): class Expr(Node):
def __init__(self, start, end, expr): def __init__(self, start, end, expr):
Node.__init__(self, start, end) Node.__init__(self, start, end)
self.expr = expr self.expr = expr
def Requires(self, node): def Requires(self, node):
# TODO(nnorwitz): impl. # TODO(nnorwitz): impl.
return False return False
def __str__(self): def __str__(self):
return self._StringHelper(self.__class__.__name__, str(self.expr)) return self._StringHelper(self.__class__.__name__, str(self.expr))
class Return(Expr): class Return(Expr):
pass pass
class Delete(Expr): class Delete(Expr):
pass pass
class Friend(Expr): class Friend(Expr):
def __init__(self, start, end, expr, namespace): def __init__(self, start, end, expr, namespace):
Expr.__init__(self, start, end, expr) Expr.__init__(self, start, end, expr)
self.namespace = namespace[:] self.namespace = namespace[:]
class Using(Node): class Using(Node):
def __init__(self, start, end, names): def __init__(self, start, end, names):
Node.__init__(self, start, end) Node.__init__(self, start, end)
self.names = names self.names = names
def __str__(self): def __str__(self):
return self._StringHelper(self.__class__.__name__, str(self.names)) return self._StringHelper(self.__class__.__name__, str(self.names))
class Parameter(Node): class Parameter(Node):
def __init__(self, start, end, name, parameter_type, default): def __init__(self, start, end, name, parameter_type, default):
Node.__init__(self, start, end) Node.__init__(self, start, end)
self.name = name self.name = name
self.type = parameter_type self.type = parameter_type
self.default = default self.default = default
def Requires(self, node): def Requires(self, node):
# TODO(nnorwitz): handle namespaces, etc. # TODO(nnorwitz): handle namespaces, etc.
return self.type.name == node.name return self.type.name == node.name
def __str__(self): def __str__(self):
name = str(self.type) name = str(self.type)
suffix = '%s %s' % (name, self.name) suffix = '%s %s' % (name, self.name)
if self.default: if self.default:
suffix += ' = ' + ''.join([d.name for d in self.default]) suffix += ' = ' + ''.join([d.name for d in self.default])
return self._StringHelper(self.__class__.__name__, suffix) return self._StringHelper(self.__class__.__name__, suffix)
class _GenericDeclaration(Node): class _GenericDeclaration(Node):
def __init__(self, start, end, name, namespace): def __init__(self, start, end, name, namespace):
Node.__init__(self, start, end) Node.__init__(self, start, end)
self.name = name self.name = name
self.namespace = namespace[:] self.namespace = namespace[:]
def FullName(self): def FullName(self):
prefix = '' prefix = ''
if self.namespace and self.namespace[-1]: if self.namespace and self.namespace[-1]:
prefix = '::'.join(self.namespace) + '::' prefix = '::'.join(self.namespace) + '::'
return prefix + self.name return prefix + self.name
def _TypeStringHelper(self, suffix): def _TypeStringHelper(self, suffix):
if self.namespace: if self.namespace:
names = [n or '<anonymous>' for n in self.namespace] names = [n or '<anonymous>' for n in self.namespace]
suffix += ' in ' + '::'.join(names) suffix += ' in ' + '::'.join(names)
return self._StringHelper(self.__class__.__name__, suffix) return self._StringHelper(self.__class__.__name__, suffix)
# TODO(nnorwitz): merge with Parameter in some way? # TODO(nnorwitz): merge with Parameter in some way?
class VariableDeclaration(_GenericDeclaration): class VariableDeclaration(_GenericDeclaration):
def __init__(self, start, end, name, var_type, initial_value, namespace): def __init__(self, start, end, name, var_type, initial_value, namespace):
_GenericDeclaration.__init__(self, start, end, name, namespace) _GenericDeclaration.__init__(self, start, end, name, namespace)
self.type = var_type self.type = var_type
self.initial_value = initial_value self.initial_value = initial_value
def Requires(self, node): def Requires(self, node):
# TODO(nnorwitz): handle namespaces, etc. # TODO(nnorwitz): handle namespaces, etc.
return self.type.name == node.name return self.type.name == node.name
def ToString(self): def ToString(self):
"""Return a string that tries to reconstitute the variable decl.""" """Return a string that tries to reconstitute the variable decl."""
suffix = '%s %s' % (self.type, self.name) suffix = '%s %s' % (self.type, self.name)
if self.initial_value: if self.initial_value:
suffix += ' = ' + self.initial_value suffix += ' = ' + self.initial_value
return suffix return suffix
def __str__(self): def __str__(self):
return self._StringHelper(self.__class__.__name__, self.ToString()) return self._StringHelper(self.__class__.__name__, self.ToString())
class Typedef(_GenericDeclaration): class Typedef(_GenericDeclaration):
def __init__(self, start, end, name, alias, namespace): def __init__(self, start, end, name, alias, namespace):
_GenericDeclaration.__init__(self, start, end, name, namespace) _GenericDeclaration.__init__(self, start, end, name, namespace)
self.alias = alias self.alias = alias
def IsDefinition(self): def IsDefinition(self):
return True return True
def IsExportable(self): def IsExportable(self):
return True return True
def Requires(self, node): def Requires(self, node):
# TODO(nnorwitz): handle namespaces, etc. # TODO(nnorwitz): handle namespaces, etc.
name = node.name name = node.name
for token in self.alias: for token in self.alias:
if token is not None and name == token.name: if token is not None and name == token.name:
return True return True
return False return False
def __str__(self): def __str__(self):
suffix = '%s, %s' % (self.name, self.alias) suffix = '%s, %s' % (self.name, self.alias)
return self._TypeStringHelper(suffix) return self._TypeStringHelper(suffix)
class _NestedType(_GenericDeclaration): class _NestedType(_GenericDeclaration):
def __init__(self, start, end, name, fields, namespace): def __init__(self, start, end, name, fields, namespace):
_GenericDeclaration.__init__(self, start, end, name, namespace) _GenericDeclaration.__init__(self, start, end, name, namespace)
self.fields = fields self.fields = fields
def IsDefinition(self): def IsDefinition(self):
return True return True
def IsExportable(self): def IsExportable(self):
return True return True
def __str__(self): def __str__(self):
suffix = '%s, {%s}' % (self.name, self.fields) suffix = '%s, {%s}' % (self.name, self.fields)
return self._TypeStringHelper(suffix) return self._TypeStringHelper(suffix)
class Union(_NestedType): class Union(_NestedType):
pass pass
class Enum(_NestedType): class Enum(_NestedType):
pass pass
class Class(_GenericDeclaration): class Class(_GenericDeclaration):
def __init__(self, start, end, name, bases, templated_types, body, namespace): def __init__(self, start, end, name, bases, templated_types, body, namespace):
_GenericDeclaration.__init__(self, start, end, name, namespace) _GenericDeclaration.__init__(self, start, end, name, namespace)
self.bases = bases self.bases = bases
self.body = body self.body = body
self.templated_types = templated_types self.templated_types = templated_types
def IsDeclaration(self): def IsDeclaration(self):
return self.bases is None and self.body is None return self.bases is None and self.body is None
def IsDefinition(self): def IsDefinition(self):
return not self.IsDeclaration() return not self.IsDeclaration()
def IsExportable(self): def IsExportable(self):
return not self.IsDeclaration() return not self.IsDeclaration()
def Requires(self, node): def Requires(self, node):
# TODO(nnorwitz): handle namespaces, etc. # TODO(nnorwitz): handle namespaces, etc.
if self.bases: if self.bases:
for token_list in self.bases: for token_list in self.bases:
# TODO(nnorwitz): bases are tokens, do name comparision. # TODO(nnorwitz): bases are tokens, do name comparision.
for token in token_list: for token in token_list:
if token.name == node.name: if token.name == node.name:
return True return True
# TODO(nnorwitz): search in body too. # TODO(nnorwitz): search in body too.
return False return False
def __str__(self): def __str__(self):
name = self.name name = self.name
if self.templated_types: if self.templated_types:
name += '<%s>' % self.templated_types name += '<%s>' % self.templated_types
suffix = '%s, %s, %s' % (name, self.bases, self.body) suffix = '%s, %s, %s' % (name, self.bases, self.body)
return self._TypeStringHelper(suffix) return self._TypeStringHelper(suffix)
class Struct(Class): class Struct(Class):
pass pass
class Function(_GenericDeclaration): class Function(_GenericDeclaration):
def __init__(self, start, end, name, return_type, parameters, def __init__(self, start, end, name, return_type, parameters,
modifiers, templated_types, body, namespace): modifiers, templated_types, body, namespace):
_GenericDeclaration.__init__(self, start, end, name, namespace) _GenericDeclaration.__init__(self, start, end, name, namespace)
converter = TypeConverter(namespace) converter = TypeConverter(namespace)
self.return_type = converter.CreateReturnType(return_type) self.return_type = converter.CreateReturnType(return_type)
self.parameters = converter.ToParameters(parameters) self.parameters = converter.ToParameters(parameters)
self.modifiers = modifiers self.modifiers = modifiers
self.body = body self.body = body
self.templated_types = templated_types self.templated_types = templated_types
def IsDeclaration(self): def IsDeclaration(self):
return self.body is None return self.body is None
def IsDefinition(self): def IsDefinition(self):
return self.body is not None return self.body is not None
def IsExportable(self): def IsExportable(self):
if self.return_type and 'static' in self.return_type.modifiers: if self.return_type and 'static' in self.return_type.modifiers:
return False return False
return None not in self.namespace return None not in self.namespace
def Requires(self, node): def Requires(self, node):
if self.parameters: if self.parameters:
# TODO(nnorwitz): parameters are tokens, do name comparision. # TODO(nnorwitz): parameters are tokens, do name comparision.
for p in self.parameters: for p in self.parameters:
if p.name == node.name: if p.name == node.name:
return True return True
# TODO(nnorwitz): search in body too. # TODO(nnorwitz): search in body too.
return False return False
def __str__(self): def __str__(self):
# TODO(nnorwitz): add templated_types. # TODO(nnorwitz): add templated_types.
suffix = ('%s %s(%s), 0x%02x, %s' % suffix = ('%s %s(%s), 0x%02x, %s' %
(self.return_type, self.name, self.parameters, (self.return_type, self.name, self.parameters,
self.modifiers, self.body)) self.modifiers, self.body))
return self._TypeStringHelper(suffix) return self._TypeStringHelper(suffix)
class Method(Function): class Method(Function):
def __init__(self, start, end, name, in_class, return_type, parameters, def __init__(self, start, end, name, in_class, return_type, parameters,
modifiers, templated_types, body, namespace): modifiers, templated_types, body, namespace):
Function.__init__(self, start, end, name, return_type, parameters, Function.__init__(self, start, end, name, return_type, parameters,
modifiers, templated_types, body, namespace) modifiers, templated_types, body, namespace)
# TODO(nnorwitz): in_class could also be a namespace which can # TODO(nnorwitz): in_class could also be a namespace which can
# mess up finding functions properly. # mess up finding functions properly.
self.in_class = in_class self.in_class = in_class
class Type(_GenericDeclaration): class Type(_GenericDeclaration):
"""Type used for any variable (eg class, primitive, struct, etc).""" """Type used for any variable (eg class, primitive, struct, etc)."""
def __init__(self, start, end, name, templated_types, modifiers, def __init__(self, start, end, name, templated_types, modifiers,
reference, pointer, array): reference, pointer, array):
""" """
Args: Args:
name: str name of main type name: str name of main type
templated_types: [Class (Type?)] template type info between <> templated_types: [Class (Type?)] template type info between <>
modifiers: [str] type modifiers (keywords) eg, const, mutable, etc. modifiers: [str] type modifiers (keywords) eg, const, mutable, etc.
reference, pointer, array: bools reference, pointer, array: bools
""" """
_GenericDeclaration.__init__(self, start, end, name, []) _GenericDeclaration.__init__(self, start, end, name, [])
self.templated_types = templated_types self.templated_types = templated_types
if not name and modifiers: if not name and modifiers:
self.name = modifiers.pop() self.name = modifiers.pop()
self.modifiers = modifiers self.modifiers = modifiers
self.reference = reference self.reference = reference
self.pointer = pointer self.pointer = pointer
self.array = array self.array = array
def __str__(self): def __str__(self):
prefix = '' prefix = ''
if self.modifiers: if self.modifiers:
prefix = ' '.join(self.modifiers) + ' ' prefix = ' '.join(self.modifiers) + ' '
name = str(self.name) name = str(self.name)
if self.templated_types: if self.templated_types:
name += '<%s>' % self.templated_types name += '<%s>' % self.templated_types
suffix = prefix + name suffix = prefix + name
if self.reference: if self.reference:
suffix += '&' suffix += '&'
if self.pointer: if self.pointer:
suffix += '*' suffix += '*'
if self.array: if self.array:
suffix += '[]' suffix += '[]'
return self._TypeStringHelper(suffix) return self._TypeStringHelper(suffix)
# By definition, Is* are always False. A Type can only exist in # By definition, Is* are always False. A Type can only exist in
# some sort of variable declaration, parameter, or return value. # some sort of variable declaration, parameter, or return value.
def IsDeclaration(self): def IsDeclaration(self):
return False return False
def IsDefinition(self): def IsDefinition(self):
return False return False
def IsExportable(self): def IsExportable(self):
return False return False
class TypeConverter(object): class TypeConverter(object):
def __init__(self, namespace_stack): def __init__(self, namespace_stack):
self.namespace_stack = namespace_stack self.namespace_stack = namespace_stack
def _GetTemplateEnd(self, tokens, start): def _GetTemplateEnd(self, tokens, start):
count = 1 count = 1
end = start end = start
while 1: while 1:
token = tokens[end] token = tokens[end]
end += 1 end += 1
if token.name == '<': if token.name == '<':
count += 1 count += 1
elif token.name == '>': elif token.name == '>':
count -= 1 count -= 1
if count == 0: if count == 0:
break break
return tokens[start:end-1], end return tokens[start:end-1], end
def ToType(self, tokens): def ToType(self, tokens):
"""Convert [Token,...] to [Class(...), ] useful for base classes. """Convert [Token,...] to [Class(...), ] useful for base classes.
For example, code like class Foo : public Bar<x, y> { ... }; For example, code like class Foo : public Bar<x, y> { ... };
the "Bar<x, y>" portion gets converted to an AST. the "Bar<x, y>" portion gets converted to an AST.
Returns: Returns:
[Class(...), ...] [Class(...), ...]
""" """
result = [] result = []
name_tokens = [] name_tokens = []
reference = pointer = array = False
def AddType(templated_types):
# Partition tokens into name and modifier tokens.
names = []
modifiers = []
for t in name_tokens:
if keywords.IsKeyword(t.name):
modifiers.append(t.name)
else:
names.append(t.name)
name = ''.join(names)
if name_tokens:
result.append(Type(name_tokens[0].start, name_tokens[-1].end,
name, templated_types, modifiers,
reference, pointer, array))
del name_tokens[:]
i = 0
end = len(tokens)
while i < end:
token = tokens[i]
if token.name == '<':
new_tokens, new_end = self._GetTemplateEnd(tokens, i+1)
AddType(self.ToType(new_tokens))
# If there is a comma after the template, we need to consume
# that here otherwise it becomes part of the name.
i = new_end
reference = pointer = array = False reference = pointer = array = False
elif token.name == ',':
def AddType(templated_types): AddType([])
# Partition tokens into name and modifier tokens. reference = pointer = array = False
names = [] elif token.name == '*':
modifiers = [] pointer = True
for t in name_tokens: elif token.name == '&':
if keywords.IsKeyword(t.name): reference = True
modifiers.append(t.name) elif token.name == '[':
else: pointer = True
names.append(t.name) elif token.name == ']':
name = ''.join(names) pass
if name_tokens: else:
result.append(Type(name_tokens[0].start, name_tokens[-1].end, name_tokens.append(token)
name, templated_types, modifiers, i += 1
reference, pointer, array))
del name_tokens[:] if name_tokens:
# No '<' in the tokens, just a simple name and no template.
i = 0 AddType([])
end = len(tokens) return result
while i < end:
token = tokens[i] def DeclarationToParts(self, parts, needs_name_removed):
if token.name == '<': name = None
new_tokens, new_end = self._GetTemplateEnd(tokens, i+1) default = []
AddType(self.ToType(new_tokens)) if needs_name_removed:
# If there is a comma after the template, we need to consume # Handle default (initial) values properly.
# that here otherwise it becomes part of the name. for i, t in enumerate(parts):
i = new_end if t.name == '=':
reference = pointer = array = False default = parts[i+1:]
elif token.name == ',': name = parts[i-1].name
AddType([]) if name == ']' and parts[i-2].name == '[':
reference = pointer = array = False name = parts[i-3].name
elif token.name == '*': i -= 1
pointer = True parts = parts[:i-1]
elif token.name == '&': break
reference = True else:
elif token.name == '[': if parts[-1].token_type == tokenize.NAME:
pointer = True name = parts.pop().name
elif token.name == ']': else:
pass # TODO(nnorwitz): this is a hack that happens for code like
else: # Register(Foo<T>); where it thinks this is a function call
name_tokens.append(token) # but it's actually a declaration.
i += 1 name = '???'
modifiers = []
if name_tokens: type_name = []
# No '<' in the tokens, just a simple name and no template. other_tokens = []
AddType([]) templated_types = []
return result i = 0
end = len(parts)
def DeclarationToParts(self, parts, needs_name_removed): while i < end:
name = None p = parts[i]
default = [] if keywords.IsKeyword(p.name):
if needs_name_removed: modifiers.append(p.name)
# Handle default (initial) values properly. elif p.name == '<':
for i, t in enumerate(parts): templated_tokens, new_end = self._GetTemplateEnd(parts, i+1)
if t.name == '=': templated_types = self.ToType(templated_tokens)
default = parts[i+1:] i = new_end - 1
name = parts[i-1].name # Don't add a spurious :: to data members being initialized.
if name == ']' and parts[i-2].name == '[': next_index = i + 1
name = parts[i-3].name if next_index < end and parts[next_index].name == '::':
i -= 1 i += 1
parts = parts[:i-1] elif p.name in ('[', ']', '='):
break # These are handled elsewhere.
else: other_tokens.append(p)
if parts[-1].token_type == tokenize.NAME: elif p.name not in ('*', '&', '>'):
name = parts.pop().name # Ensure that names have a space between them.
else: if (type_name and type_name[-1].token_type == tokenize.NAME and
# TODO(nnorwitz): this is a hack that happens for code like p.token_type == tokenize.NAME):
# Register(Foo<T>); where it thinks this is a function call type_name.append(tokenize.Token(tokenize.SYNTAX, ' ', 0, 0))
# but it's actually a declaration. type_name.append(p)
name = '???' else:
modifiers = [] other_tokens.append(p)
type_name = [] i += 1
other_tokens = [] type_name = ''.join([t.name for t in type_name])
templated_types = [] return name, type_name, templated_types, modifiers, default, other_tokens
i = 0
end = len(parts) def ToParameters(self, tokens):
while i < end: if not tokens:
p = parts[i] return []
if keywords.IsKeyword(p.name):
modifiers.append(p.name) result = []
elif p.name == '<': name = type_name = ''
templated_tokens, new_end = self._GetTemplateEnd(parts, i+1) type_modifiers = []
templated_types = self.ToType(templated_tokens) pointer = reference = array = False
i = new_end - 1 first_token = None
# Don't add a spurious :: to data members being initialized. default = []
next_index = i + 1
if next_index < end and parts[next_index].name == '::': def AddParameter(end):
i += 1 if default:
elif p.name in ('[', ']', '='): del default[0] # Remove flag.
# These are handled elsewhere. parts = self.DeclarationToParts(type_modifiers, True)
other_tokens.append(p) (name, type_name, templated_types, modifiers,
elif p.name not in ('*', '&', '>'): unused_default, unused_other_tokens) = parts
# Ensure that names have a space between them. parameter_type = Type(first_token.start, first_token.end,
if (type_name and type_name[-1].token_type == tokenize.NAME and type_name, templated_types, modifiers,
p.token_type == tokenize.NAME): reference, pointer, array)
type_name.append(tokenize.Token(tokenize.SYNTAX, ' ', 0, 0)) p = Parameter(first_token.start, end, name,
type_name.append(p) parameter_type, default)
else: result.append(p)
other_tokens.append(p)
i += 1 template_count = 0
type_name = ''.join([t.name for t in type_name]) brace_count = 0
return name, type_name, templated_types, modifiers, default, other_tokens for s in tokens:
if not first_token:
def ToParameters(self, tokens): first_token = s
if not tokens:
return [] # Check for braces before templates, as we can have unmatched '<>'
# inside default arguments.
result = [] if s.name == '{':
brace_count += 1
elif s.name == '}':
brace_count -= 1
if brace_count > 0:
type_modifiers.append(s)
continue
if s.name == '<':
template_count += 1
elif s.name == '>':
template_count -= 1
if template_count > 0:
type_modifiers.append(s)
continue
if s.name == ',':
AddParameter(s.start)
name = type_name = '' name = type_name = ''
type_modifiers = [] type_modifiers = []
pointer = reference = array = False pointer = reference = array = False
first_token = None first_token = None
default = [] default = []
elif s.name == '*':
def AddParameter(end): pointer = True
if default: elif s.name == '&':
del default[0] # Remove flag. reference = True
parts = self.DeclarationToParts(type_modifiers, True) elif s.name == '[':
(name, type_name, templated_types, modifiers, array = True
unused_default, unused_other_tokens) = parts elif s.name == ']':
parameter_type = Type(first_token.start, first_token.end, pass # Just don't add to type_modifiers.
type_name, templated_types, modifiers, elif s.name == '=':
reference, pointer, array) # Got a default value. Add any value (None) as a flag.
p = Parameter(first_token.start, end, name, default.append(None)
parameter_type, default) elif default:
result.append(p) default.append(s)
else:
template_count = 0 type_modifiers.append(s)
for s in tokens: AddParameter(tokens[-1].end)
if not first_token: return result
first_token = s
if s.name == '<': def CreateReturnType(self, return_type_seq):
template_count += 1 if not return_type_seq:
elif s.name == '>': return None
template_count -= 1 start = return_type_seq[0].start
if template_count > 0: end = return_type_seq[-1].end
type_modifiers.append(s) _, name, templated_types, modifiers, default, other_tokens = \
continue self.DeclarationToParts(return_type_seq, False)
names = [n.name for n in other_tokens]
if s.name == ',': reference = '&' in names
AddParameter(s.start) pointer = '*' in names
name = type_name = '' array = '[' in names
type_modifiers = [] return Type(start, end, name, templated_types, modifiers,
pointer = reference = array = False reference, pointer, array)
first_token = None
default = [] def GetTemplateIndices(self, names):
elif s.name == '*': # names is a list of strings.
pointer = True start = names.index('<')
elif s.name == '&': end = len(names) - 1
reference = True while end > 0:
elif s.name == '[': if names[end] == '>':
array = True break
elif s.name == ']': end -= 1
pass # Just don't add to type_modifiers. return start, end+1
elif s.name == '=':
# Got a default value. Add any value (None) as a flag.
default.append(None)
elif default:
default.append(s)
else:
type_modifiers.append(s)
AddParameter(tokens[-1].end)
return result
def CreateReturnType(self, return_type_seq):
if not return_type_seq:
return None
start = return_type_seq[0].start
end = return_type_seq[-1].end
_, name, templated_types, modifiers, default, other_tokens = \
self.DeclarationToParts(return_type_seq, False)
names = [n.name for n in other_tokens]
reference = '&' in names
pointer = '*' in names
array = '[' in names
return Type(start, end, name, templated_types, modifiers,
reference, pointer, array)
def GetTemplateIndices(self, names):
# names is a list of strings.
start = names.index('<')
end = len(names) - 1
while end > 0:
if names[end] == '>':
break
end -= 1
return start, end+1
class AstBuilder(object): class AstBuilder(object):
def __init__(self, token_stream, filename, in_class='', visibility=None, def __init__(self, token_stream, filename, in_class='', visibility=None,
namespace_stack=[]): namespace_stack=[]):
self.tokens = token_stream self.tokens = token_stream
self.filename = filename self.filename = filename
# TODO(nnorwitz): use a better data structure (deque) for the queue. # TODO(nnorwitz): use a better data structure (deque) for the queue.
# Switching directions of the "queue" improved perf by about 25%. # Switching directions of the "queue" improved perf by about 25%.
# Using a deque should be even better since we access from both sides. # Using a deque should be even better since we access from both sides.
self.token_queue = [] self.token_queue = []
self.namespace_stack = namespace_stack[:] self.namespace_stack = namespace_stack[:]
self.in_class = in_class self.in_class = in_class
if in_class is None: if in_class is None:
self.in_class_name_only = None self.in_class_name_only = None
else: else:
self.in_class_name_only = in_class.split('::')[-1] self.in_class_name_only = in_class.split('::')[-1]
self.visibility = visibility self.visibility = visibility
self.in_function = False self.in_function = False
self.current_token = None self.current_token = None
# Keep the state whether we are currently handling a typedef or not. # Keep the state whether we are currently handling a typedef or not.
self._handling_typedef = False self._handling_typedef = False
self.converter = TypeConverter(self.namespace_stack) self.converter = TypeConverter(self.namespace_stack)
def HandleError(self, msg, token): def HandleError(self, msg, token):
printable_queue = list(reversed(self.token_queue[-20:])) printable_queue = list(reversed(self.token_queue[-20:]))
sys.stderr.write('Got %s in %s @ %s %s\n' % sys.stderr.write('Got %s in %s @ %s %s\n' %
(msg, self.filename, token, printable_queue)) (msg, self.filename, token, printable_queue))
def Generate(self): def Generate(self):
while 1: while 1:
token = self._GetNextToken() token = self._GetNextToken()
if not token: if not token:
break break
# Get the next token. # Get the next token.
self.current_token = token self.current_token = token
# Dispatch on the next token type. # Dispatch on the next token type.
if token.token_type == _INTERNAL_TOKEN: if token.token_type == _INTERNAL_TOKEN:
if token.name == _NAMESPACE_POP: if token.name == _NAMESPACE_POP:
self.namespace_stack.pop() self.namespace_stack.pop()
continue continue
try: try:
result = self._GenerateOne(token) result = self._GenerateOne(token)
if result is not None: if result is not None:
yield result yield result
except: except:
self.HandleError('exception', token) self.HandleError('exception', token)
raise raise
def _CreateVariable(self, pos_token, name, type_name, type_modifiers, def _CreateVariable(self, pos_token, name, type_name, type_modifiers,
ref_pointer_name_seq, templated_types, value=None): ref_pointer_name_seq, templated_types, value=None):
reference = '&' in ref_pointer_name_seq reference = '&' in ref_pointer_name_seq
pointer = '*' in ref_pointer_name_seq pointer = '*' in ref_pointer_name_seq
array = '[' in ref_pointer_name_seq array = '[' in ref_pointer_name_seq
var_type = Type(pos_token.start, pos_token.end, type_name, var_type = Type(pos_token.start, pos_token.end, type_name,
templated_types, type_modifiers, templated_types, type_modifiers,
reference, pointer, array) reference, pointer, array)
return VariableDeclaration(pos_token.start, pos_token.end, return VariableDeclaration(pos_token.start, pos_token.end,
name, var_type, value, self.namespace_stack) name, var_type, value, self.namespace_stack)
def _GenerateOne(self, token): def _GenerateOne(self, token):
if token.token_type == tokenize.NAME: if token.token_type == tokenize.NAME:
if (keywords.IsKeyword(token.name) and if (keywords.IsKeyword(token.name) and
not keywords.IsBuiltinType(token.name)): not keywords.IsBuiltinType(token.name)):
if token.name == 'enum': if token.name == 'enum':
# Pop the next token and only put it back if it's not # Pop the next token and only put it back if it's not
# 'class'. This allows us to support the two-token # 'class'. This allows us to support the two-token
# 'enum class' keyword as if it were simply 'enum'. # 'enum class' keyword as if it were simply 'enum'.
next = self._GetNextToken() next = self._GetNextToken()
if next.name != 'class': if next.name != 'class':
self._AddBackToken(next) self._AddBackToken(next)
method = getattr(self, 'handle_' + token.name) method = getattr(self, 'handle_' + token.name)
return method() return method()
elif token.name == self.in_class_name_only: elif token.name == self.in_class_name_only:
# The token name is the same as the class, must be a ctor if # The token name is the same as the class, must be a ctor if
# there is a paren. Otherwise, it's the return type. # there is a paren. Otherwise, it's the return type.
# Peek ahead to get the next token to figure out which. # Peek ahead to get the next token to figure out which.
next = self._GetNextToken() next = self._GetNextToken()
self._AddBackToken(next) self._AddBackToken(next)
if next.token_type == tokenize.SYNTAX and next.name == '(': if next.token_type == tokenize.SYNTAX and next.name == '(':
return self._GetMethod([token], FUNCTION_CTOR, None, True) return self._GetMethod([token], FUNCTION_CTOR, None, True)
# Fall through--handle like any other method. # Fall through--handle like any other method.
# Handle data or function declaration/definition. # Handle data or function declaration/definition.
syntax = tokenize.SYNTAX syntax = tokenize.SYNTAX
temp_tokens, last_token = \ temp_tokens, last_token = \
self._GetVarTokensUpToIgnoringTemplates(syntax, self._GetVarTokensUpToIgnoringTemplates(syntax,
'(', ';', '{', '[') '(', ';', '{', '[')
temp_tokens.insert(0, token) temp_tokens.insert(0, token)
if last_token.name == '(': if last_token.name == '(':
# If there is an assignment before the paren, # If there is an assignment before the paren,
# this is an expression, not a method. # this is an expression, not a method.
expr = bool([e for e in temp_tokens if e.name == '=']) expr = bool([e for e in temp_tokens if e.name == '='])
if expr: if expr:
new_temp = self._GetTokensUpTo(tokenize.SYNTAX, ';') new_temp = self._GetTokensUpTo(tokenize.SYNTAX, ';')
temp_tokens.append(last_token) temp_tokens.append(last_token)
temp_tokens.extend(new_temp) temp_tokens.extend(new_temp)
last_token = tokenize.Token(tokenize.SYNTAX, ';', 0, 0) last_token = tokenize.Token(tokenize.SYNTAX, ';', 0, 0)
if last_token.name == '[': if last_token.name == '[':
# Handle array, this isn't a method, unless it's an operator. # Handle array, this isn't a method, unless it's an operator.
# TODO(nnorwitz): keep the size somewhere. # TODO(nnorwitz): keep the size somewhere.
# unused_size = self._GetTokensUpTo(tokenize.SYNTAX, ']') # unused_size = self._GetTokensUpTo(tokenize.SYNTAX, ']')
temp_tokens.append(last_token) temp_tokens.append(last_token)
if temp_tokens[-2].name == 'operator': if temp_tokens[-2].name == 'operator':
temp_tokens.append(self._GetNextToken()) temp_tokens.append(self._GetNextToken())
else:
temp_tokens2, last_token = \
self._GetVarTokensUpTo(tokenize.SYNTAX, ';')
temp_tokens.extend(temp_tokens2)
if last_token.name == ';':
# Handle data, this isn't a method.
parts = self.converter.DeclarationToParts(temp_tokens, True)
(name, type_name, templated_types, modifiers, default,
unused_other_tokens) = parts
t0 = temp_tokens[0]
names = [t.name for t in temp_tokens]
if templated_types:
start, end = self.converter.GetTemplateIndices(names)
names = names[:start] + names[end:]
default = ''.join([t.name for t in default])
return self._CreateVariable(t0, name, type_name, modifiers,
names, templated_types, default)
if last_token.name == '{':
self._AddBackTokens(temp_tokens[1:])
self._AddBackToken(last_token)
method_name = temp_tokens[0].name
method = getattr(self, 'handle_' + method_name, None)
if not method:
# Must be declaring a variable.
# TODO(nnorwitz): handle the declaration.
return None
return method()
return self._GetMethod(temp_tokens, 0, None, False)
elif token.token_type == tokenize.SYNTAX:
if token.name == '~' and self.in_class:
# Must be a dtor (probably not in method body).
token = self._GetNextToken()
# self.in_class can contain A::Name, but the dtor will only
# be Name. Make sure to compare against the right value.
if (token.token_type == tokenize.NAME and
token.name == self.in_class_name_only):
return self._GetMethod([token], FUNCTION_DTOR, None, True)
# TODO(nnorwitz): handle a lot more syntax.
elif token.token_type == tokenize.PREPROCESSOR:
# TODO(nnorwitz): handle more preprocessor directives.
# token starts with a #, so remove it and strip whitespace.
name = token.name[1:].lstrip()
if name.startswith('include'):
# Remove "include".
name = name[7:].strip()
assert name
# Handle #include \<newline> "header-on-second-line.h".
if name.startswith('\\'):
name = name[1:].strip()
assert name[0] in '<"', token
assert name[-1] in '>"', token
system = name[0] == '<'
filename = name[1:-1]
return Include(token.start, token.end, filename, system)
if name.startswith('define'):
# Remove "define".
name = name[6:].strip()
assert name
value = ''
for i, c in enumerate(name):
if c.isspace():
value = name[i:].lstrip()
name = name[:i]
break
return Define(token.start, token.end, name, value)
if name.startswith('if') and name[2:3].isspace():
condition = name[3:].strip()
if condition.startswith('0') or condition.startswith('(0)'):
self._SkipIf0Blocks()
return None
def _GetTokensUpTo(self, expected_token_type, expected_token):
return self._GetVarTokensUpTo(expected_token_type, expected_token)[0]
def _GetVarTokensUpTo(self, expected_token_type, *expected_tokens):
last_token = self._GetNextToken()
tokens = []
while (last_token.token_type != expected_token_type or
last_token.name not in expected_tokens):
tokens.append(last_token)
last_token = self._GetNextToken()
return tokens, last_token
# Same as _GetVarTokensUpTo, but skips over '<...>' which could contain an
# expected token.
def _GetVarTokensUpToIgnoringTemplates(self, expected_token_type,
*expected_tokens):
last_token = self._GetNextToken()
tokens = []
nesting = 0
while (nesting > 0 or
last_token.token_type != expected_token_type or
last_token.name not in expected_tokens):
tokens.append(last_token)
last_token = self._GetNextToken()
if last_token.name == '<':
nesting += 1
elif last_token.name == '>':
nesting -= 1
return tokens, last_token
# TODO(nnorwitz): remove _IgnoreUpTo() it shouldn't be necesary.
def _IgnoreUpTo(self, token_type, token):
unused_tokens = self._GetTokensUpTo(token_type, token)
def _SkipIf0Blocks(self):
count = 1
while 1:
token = self._GetNextToken()
if token.token_type != tokenize.PREPROCESSOR:
continue
name = token.name[1:].lstrip()
if name.startswith('endif'):
count -= 1
if count == 0:
break
elif name.startswith('if'):
count += 1
def _GetMatchingChar(self, open_paren, close_paren, GetNextToken=None):
if GetNextToken is None:
GetNextToken = self._GetNextToken
# Assumes the current token is open_paren and we will consume
# and return up to the close_paren.
count = 1
token = GetNextToken()
while 1:
if token.token_type == tokenize.SYNTAX:
if token.name == open_paren:
count += 1
elif token.name == close_paren:
count -= 1
if count == 0:
break
yield token
token = GetNextToken()
yield token
def _GetParameters(self):
return self._GetMatchingChar('(', ')')
def GetScope(self):
return self._GetMatchingChar('{', '}')
def _GetNextToken(self):
if self.token_queue:
return self.token_queue.pop()
try:
return next(self.tokens)
except StopIteration:
return
def _AddBackToken(self, token):
if token.whence == tokenize.WHENCE_STREAM:
token.whence = tokenize.WHENCE_QUEUE
self.token_queue.insert(0, token)
else: else:
assert token.whence == tokenize.WHENCE_QUEUE, token temp_tokens2, last_token = \
self.token_queue.append(token) self._GetVarTokensUpTo(tokenize.SYNTAX, ';')
temp_tokens.extend(temp_tokens2)
def _AddBackTokens(self, tokens):
if tokens: if last_token.name == ';':
if tokens[-1].whence == tokenize.WHENCE_STREAM: # Handle data, this isn't a method.
for token in tokens: parts = self.converter.DeclarationToParts(temp_tokens, True)
token.whence = tokenize.WHENCE_QUEUE (name, type_name, templated_types, modifiers, default,
self.token_queue[:0] = reversed(tokens) unused_other_tokens) = parts
else:
assert tokens[-1].whence == tokenize.WHENCE_QUEUE, tokens t0 = temp_tokens[0]
self.token_queue.extend(reversed(tokens)) names = [t.name for t in temp_tokens]
if templated_types:
def GetName(self, seq=None): start, end = self.converter.GetTemplateIndices(names)
"""Returns ([tokens], next_token_info).""" names = names[:start] + names[end:]
GetNextToken = self._GetNextToken default = ''.join([t.name for t in default])
if seq is not None: return self._CreateVariable(t0, name, type_name, modifiers,
it = iter(seq) names, templated_types, default)
GetNextToken = lambda: next(it) if last_token.name == '{':
next_token = GetNextToken() self._AddBackTokens(temp_tokens[1:])
tokens = [] self._AddBackToken(last_token)
last_token_was_name = False method_name = temp_tokens[0].name
while (next_token.token_type == tokenize.NAME or method = getattr(self, 'handle_' + method_name, None)
(next_token.token_type == tokenize.SYNTAX and if not method:
next_token.name in ('::', '<'))): # Must be declaring a variable.
# Two NAMEs in a row means the identifier should terminate. # TODO(nnorwitz): handle the declaration.
# It's probably some sort of variable declaration. return None
if last_token_was_name and next_token.token_type == tokenize.NAME: return method()
break return self._GetMethod(temp_tokens, 0, None, False)
last_token_was_name = next_token.token_type == tokenize.NAME elif token.token_type == tokenize.SYNTAX:
tokens.append(next_token) if token.name == '~' and self.in_class:
# Handle templated names. # Must be a dtor (probably not in method body).
if next_token.name == '<': token = self._GetNextToken()
tokens.extend(self._GetMatchingChar('<', '>', GetNextToken)) # self.in_class can contain A::Name, but the dtor will only
last_token_was_name = True # be Name. Make sure to compare against the right value.
next_token = GetNextToken() if (token.token_type == tokenize.NAME and
return tokens, next_token token.name == self.in_class_name_only):
return self._GetMethod([token], FUNCTION_DTOR, None, True)
def GetMethod(self, modifiers, templated_types): # TODO(nnorwitz): handle a lot more syntax.
return_type_and_name = self._GetTokensUpTo(tokenize.SYNTAX, '(') elif token.token_type == tokenize.PREPROCESSOR:
assert len(return_type_and_name) >= 1 # TODO(nnorwitz): handle more preprocessor directives.
return self._GetMethod(return_type_and_name, modifiers, templated_types, # token starts with a #, so remove it and strip whitespace.
False) name = token.name[1:].lstrip()
if name.startswith('include'):
def _GetMethod(self, return_type_and_name, modifiers, templated_types, # Remove "include".
get_paren): name = name[7:].strip()
template_portion = None assert name
if get_paren: # Handle #include \<newline> "header-on-second-line.h".
token = self._GetNextToken() if name.startswith('\\'):
assert token.token_type == tokenize.SYNTAX, token name = name[1:].strip()
if token.name == '<': assert name[0] in '<"', token
# Handle templatized dtors. assert name[-1] in '>"', token
template_portion = [token] system = name[0] == '<'
template_portion.extend(self._GetMatchingChar('<', '>')) filename = name[1:-1]
token = self._GetNextToken() return Include(token.start, token.end, filename, system)
assert token.token_type == tokenize.SYNTAX, token if name.startswith('define'):
assert token.name == '(', token # Remove "define".
name = name[6:].strip()
name = return_type_and_name.pop() assert name
# Handle templatized ctors. value = ''
if name.name == '>': for i, c in enumerate(name):
index = 1 if c.isspace():
while return_type_and_name[index].name != '<': value = name[i:].lstrip()
index += 1 name = name[:i]
template_portion = return_type_and_name[index:] + [name] break
del return_type_and_name[index:] return Define(token.start, token.end, name, value)
name = return_type_and_name.pop() if name.startswith('if') and name[2:3].isspace():
elif name.name == ']': condition = name[3:].strip()
rt = return_type_and_name if condition.startswith('0') or condition.startswith('(0)'):
assert rt[-1].name == '[', return_type_and_name self._SkipIf0Blocks()
assert rt[-2].name == 'operator', return_type_and_name return None
name_seq = return_type_and_name[-2:]
del return_type_and_name[-2:] def _GetTokensUpTo(self, expected_token_type, expected_token):
name = tokenize.Token(tokenize.NAME, 'operator[]', return self._GetVarTokensUpTo(expected_token_type, expected_token)[0]
name_seq[0].start, name.end)
# Get the open paren so _GetParameters() below works. def _GetVarTokensUpTo(self, expected_token_type, *expected_tokens):
unused_open_paren = self._GetNextToken() last_token = self._GetNextToken()
tokens = []
# TODO(nnorwitz): store template_portion. while (last_token.token_type != expected_token_type or
return_type = return_type_and_name last_token.name not in expected_tokens):
indices = name tokens.append(last_token)
if return_type: last_token = self._GetNextToken()
indices = return_type[0] return tokens, last_token
# Force ctor for templatized ctors. # Same as _GetVarTokensUpTo, but skips over '<...>' which could contain an
if name.name == self.in_class and not modifiers: # expected token.
modifiers |= FUNCTION_CTOR def _GetVarTokensUpToIgnoringTemplates(self, expected_token_type,
parameters = list(self._GetParameters()) *expected_tokens):
del parameters[-1] # Remove trailing ')'. last_token = self._GetNextToken()
tokens = []
# Handling operator() is especially weird. nesting = 0
if name.name == 'operator' and not parameters: while (nesting > 0 or
token = self._GetNextToken() last_token.token_type != expected_token_type or
assert token.name == '(', token last_token.name not in expected_tokens):
parameters = list(self._GetParameters()) tokens.append(last_token)
del parameters[-1] # Remove trailing ')'. last_token = self._GetNextToken()
if last_token.name == '<':
nesting += 1
elif last_token.name == '>':
nesting -= 1
return tokens, last_token
# TODO(nnorwitz): remove _IgnoreUpTo() it shouldn't be necesary.
def _IgnoreUpTo(self, token_type, token):
unused_tokens = self._GetTokensUpTo(token_type, token)
def _SkipIf0Blocks(self):
count = 1
while 1:
token = self._GetNextToken()
if token.token_type != tokenize.PREPROCESSOR:
continue
name = token.name[1:].lstrip()
if name.startswith('endif'):
count -= 1
if count == 0:
break
elif name.startswith('if'):
count += 1
def _GetMatchingChar(self, open_paren, close_paren, GetNextToken=None):
if GetNextToken is None:
GetNextToken = self._GetNextToken
# Assumes the current token is open_paren and we will consume
# and return up to the close_paren.
count = 1
token = GetNextToken()
while 1:
if token.token_type == tokenize.SYNTAX:
if token.name == open_paren:
count += 1
elif token.name == close_paren:
count -= 1
if count == 0:
break
yield token
token = GetNextToken()
yield token
def _GetParameters(self):
return self._GetMatchingChar('(', ')')
def GetScope(self):
return self._GetMatchingChar('{', '}')
def _GetNextToken(self):
if self.token_queue:
return self.token_queue.pop()
try:
return next(self.tokens)
except StopIteration:
return
def _AddBackToken(self, token):
if token.whence == tokenize.WHENCE_STREAM:
token.whence = tokenize.WHENCE_QUEUE
self.token_queue.insert(0, token)
else:
assert token.whence == tokenize.WHENCE_QUEUE, token
self.token_queue.append(token)
def _AddBackTokens(self, tokens):
if tokens:
if tokens[-1].whence == tokenize.WHENCE_STREAM:
for token in tokens:
token.whence = tokenize.WHENCE_QUEUE
self.token_queue[:0] = reversed(tokens)
else:
assert tokens[-1].whence == tokenize.WHENCE_QUEUE, tokens
self.token_queue.extend(reversed(tokens))
def GetName(self, seq=None):
"""Returns ([tokens], next_token_info)."""
GetNextToken = self._GetNextToken
if seq is not None:
it = iter(seq)
GetNextToken = lambda: next(it)
next_token = GetNextToken()
tokens = []
last_token_was_name = False
while (next_token.token_type == tokenize.NAME or
(next_token.token_type == tokenize.SYNTAX and
next_token.name in ('::', '<'))):
# Two NAMEs in a row means the identifier should terminate.
# It's probably some sort of variable declaration.
if last_token_was_name and next_token.token_type == tokenize.NAME:
break
last_token_was_name = next_token.token_type == tokenize.NAME
tokens.append(next_token)
# Handle templated names.
if next_token.name == '<':
tokens.extend(self._GetMatchingChar('<', '>', GetNextToken))
last_token_was_name = True
next_token = GetNextToken()
return tokens, next_token
def GetMethod(self, modifiers, templated_types):
return_type_and_name = self._GetTokensUpTo(tokenize.SYNTAX, '(')
assert len(return_type_and_name) >= 1
return self._GetMethod(return_type_and_name, modifiers, templated_types,
False)
def _GetMethod(self, return_type_and_name, modifiers, templated_types,
get_paren):
template_portion = None
if get_paren:
token = self._GetNextToken()
assert token.token_type == tokenize.SYNTAX, token
if token.name == '<':
# Handle templatized dtors.
template_portion = [token]
template_portion.extend(self._GetMatchingChar('<', '>'))
token = self._GetNextToken()
assert token.token_type == tokenize.SYNTAX, token
assert token.name == '(', token
name = return_type_and_name.pop()
# Handle templatized ctors.
if name.name == '>':
index = 1
while return_type_and_name[index].name != '<':
index += 1
template_portion = return_type_and_name[index:] + [name]
del return_type_and_name[index:]
name = return_type_and_name.pop()
elif name.name == ']':
rt = return_type_and_name
assert rt[-1].name == '[', return_type_and_name
assert rt[-2].name == 'operator', return_type_and_name
name_seq = return_type_and_name[-2:]
del return_type_and_name[-2:]
name = tokenize.Token(tokenize.NAME, 'operator[]',
name_seq[0].start, name.end)
# Get the open paren so _GetParameters() below works.
unused_open_paren = self._GetNextToken()
# TODO(nnorwitz): store template_portion.
return_type = return_type_and_name
indices = name
if return_type:
indices = return_type[0]
# Force ctor for templatized ctors.
if name.name == self.in_class and not modifiers:
modifiers |= FUNCTION_CTOR
parameters = list(self._GetParameters())
del parameters[-1] # Remove trailing ')'.
# Handling operator() is especially weird.
if name.name == 'operator' and not parameters:
token = self._GetNextToken()
assert token.name == '(', token
parameters = list(self._GetParameters())
del parameters[-1] # Remove trailing ')'.
token = self._GetNextToken()
while token.token_type == tokenize.NAME:
modifier_token = token
token = self._GetNextToken()
if modifier_token.name == 'const':
modifiers |= FUNCTION_CONST
elif modifier_token.name == '__attribute__':
# TODO(nnorwitz): handle more __attribute__ details.
modifiers |= FUNCTION_ATTRIBUTE
assert token.name == '(', token
# Consume everything between the (parens).
unused_tokens = list(self._GetMatchingChar('(', ')'))
token = self._GetNextToken()
elif modifier_token.name == 'throw':
modifiers |= FUNCTION_THROW
assert token.name == '(', token
# Consume everything between the (parens).
unused_tokens = list(self._GetMatchingChar('(', ')'))
token = self._GetNextToken()
elif modifier_token.name == 'override':
modifiers |= FUNCTION_OVERRIDE
elif modifier_token.name == modifier_token.name.upper():
# HACK(nnorwitz): assume that all upper-case names
# are some macro we aren't expanding.
modifiers |= FUNCTION_UNKNOWN_ANNOTATION
else:
self.HandleError('unexpected token', modifier_token)
assert token.token_type == tokenize.SYNTAX, token
# Handle ctor initializers.
if token.name == ':':
# TODO(nnorwitz): anything else to handle for initializer list?
while token.name != ';' and token.name != '{':
token = self._GetNextToken() token = self._GetNextToken()
while token.token_type == tokenize.NAME:
modifier_token = token
token = self._GetNextToken()
if modifier_token.name == 'const':
modifiers |= FUNCTION_CONST
elif modifier_token.name == '__attribute__':
# TODO(nnorwitz): handle more __attribute__ details.
modifiers |= FUNCTION_ATTRIBUTE
assert token.name == '(', token
# Consume everything between the (parens).
unused_tokens = list(self._GetMatchingChar('(', ')'))
token = self._GetNextToken()
elif modifier_token.name == 'throw':
modifiers |= FUNCTION_THROW
assert token.name == '(', token
# Consume everything between the (parens).
unused_tokens = list(self._GetMatchingChar('(', ')'))
token = self._GetNextToken()
elif modifier_token.name == 'override':
modifiers |= FUNCTION_OVERRIDE
elif modifier_token.name == modifier_token.name.upper():
# HACK(nnorwitz): assume that all upper-case names
# are some macro we aren't expanding.
modifiers |= FUNCTION_UNKNOWN_ANNOTATION
else:
self.HandleError('unexpected token', modifier_token)
# Handle pointer to functions that are really data but look
# like method declarations.
if token.name == '(':
if parameters[0].name == '*':
# name contains the return type.
name = parameters.pop()
# parameters contains the name of the data.
modifiers = [p.name for p in parameters]
# Already at the ( to open the parameter list.
function_parameters = list(self._GetMatchingChar('(', ')'))
del function_parameters[-1] # Remove trailing ')'.
# TODO(nnorwitz): store the function_parameters.
token = self._GetNextToken()
assert token.token_type == tokenize.SYNTAX, token assert token.token_type == tokenize.SYNTAX, token
# Handle ctor initializers. assert token.name == ';', token
if token.name == ':': return self._CreateVariable(indices, name.name, indices.name,
# TODO(nnorwitz): anything else to handle for initializer list? modifiers, '', None)
while token.name != ';' and token.name != '{': # At this point, we got something like:
token = self._GetNextToken() # return_type (type::*name_)(params);
# This is a data member called name_ that is a function pointer.
# Handle pointer to functions that are really data but look # With this code: void (sq_type::*field_)(string&);
# like method declarations. # We get: name=void return_type=[] parameters=sq_type ... field_
if token.name == '(': # TODO(nnorwitz): is return_type always empty?
if parameters[0].name == '*': # TODO(nnorwitz): this isn't even close to being correct.
# name contains the return type. # Just put in something so we don't crash and can move on.
name = parameters.pop() real_name = parameters[-1]
# parameters contains the name of the data. modifiers = [p.name for p in self._GetParameters()]
modifiers = [p.name for p in parameters] del modifiers[-1] # Remove trailing ')'.
# Already at the ( to open the parameter list. return self._CreateVariable(indices, real_name.name, indices.name,
function_parameters = list(self._GetMatchingChar('(', ')')) modifiers, '', None)
del function_parameters[-1] # Remove trailing ')'.
# TODO(nnorwitz): store the function_parameters. if token.name == '{':
token = self._GetNextToken() body = list(self.GetScope())
assert token.token_type == tokenize.SYNTAX, token del body[-1] # Remove trailing '}'.
assert token.name == ';', token else:
return self._CreateVariable(indices, name.name, indices.name, body = None
modifiers, '', None) if token.name == '=':
# At this point, we got something like: token = self._GetNextToken()
# return_type (type::*name_)(params);
# This is a data member called name_ that is a function pointer.
# With this code: void (sq_type::*field_)(string&);
# We get: name=void return_type=[] parameters=sq_type ... field_
# TODO(nnorwitz): is return_type always empty?
# TODO(nnorwitz): this isn't even close to being correct.
# Just put in something so we don't crash and can move on.
real_name = parameters[-1]
modifiers = [p.name for p in self._GetParameters()]
del modifiers[-1] # Remove trailing ')'.
return self._CreateVariable(indices, real_name.name, indices.name,
modifiers, '', None)
if token.name == '{':
body = list(self.GetScope())
del body[-1] # Remove trailing '}'.
else:
body = None
if token.name == '=':
token = self._GetNextToken()
if token.name == 'default' or token.name == 'delete':
# Ignore explicitly defaulted and deleted special members
# in C++11.
token = self._GetNextToken()
else:
# Handle pure-virtual declarations.
assert token.token_type == tokenize.CONSTANT, token
assert token.name == '0', token
modifiers |= FUNCTION_PURE_VIRTUAL
token = self._GetNextToken()
if token.name == '[':
# TODO(nnorwitz): store tokens and improve parsing.
# template <typename T, size_t N> char (&ASH(T (&seq)[N]))[N];
tokens = list(self._GetMatchingChar('[', ']'))
token = self._GetNextToken()
assert token.name == ';', (token, return_type_and_name, parameters)
# Looks like we got a method, not a function.
if len(return_type) > 2 and return_type[-1].name == '::':
return_type, in_class = \
self._GetReturnTypeAndClassName(return_type)
return Method(indices.start, indices.end, name.name, in_class,
return_type, parameters, modifiers, templated_types,
body, self.namespace_stack)
return Function(indices.start, indices.end, name.name, return_type,
parameters, modifiers, templated_types, body,
self.namespace_stack)
def _GetReturnTypeAndClassName(self, token_seq):
# Splitting the return type from the class name in a method
# can be tricky. For example, Return::Type::Is::Hard::To::Find().
# Where is the return type and where is the class name?
# The heuristic used is to pull the last name as the class name.
# This includes all the templated type info.
# TODO(nnorwitz): if there is only One name like in the
# example above, punt and assume the last bit is the class name.
# Ignore a :: prefix, if exists so we can find the first real name.
i = 0
if token_seq[0].name == '::':
i = 1
# Ignore a :: suffix, if exists.
end = len(token_seq) - 1
if token_seq[end-1].name == '::':
end -= 1
# Make a copy of the sequence so we can append a sentinel
# value. This is required for GetName will has to have some
# terminating condition beyond the last name.
seq_copy = token_seq[i:end]
seq_copy.append(tokenize.Token(tokenize.SYNTAX, '', 0, 0))
names = []
while i < end:
# Iterate through the sequence parsing out each name.
new_name, next = self.GetName(seq_copy[i:])
assert new_name, 'Got empty new_name, next=%s' % next
# We got a pointer or ref. Add it to the name.
if next and next.token_type == tokenize.SYNTAX:
new_name.append(next)
names.append(new_name)
i += len(new_name)
# Now that we have the names, it's time to undo what we did.
# Remove the sentinel value.
names[-1].pop()
# Flatten the token sequence for the return type.
return_type = [e for seq in names[:-1] for e in seq]
# The class name is the last name.
class_name = names[-1]
return return_type, class_name
def handle_bool(self):
pass
def handle_char(self):
pass
def handle_int(self):
pass
def handle_long(self): if token.name == 'default' or token.name == 'delete':
pass # Ignore explicitly defaulted and deleted special members
# in C++11.
token = self._GetNextToken()
else:
# Handle pure-virtual declarations.
assert token.token_type == tokenize.CONSTANT, token
assert token.name == '0', token
modifiers |= FUNCTION_PURE_VIRTUAL
token = self._GetNextToken()
if token.name == '[':
# TODO(nnorwitz): store tokens and improve parsing.
# template <typename T, size_t N> char (&ASH(T (&seq)[N]))[N];
tokens = list(self._GetMatchingChar('[', ']'))
token = self._GetNextToken()
def handle_short(self): assert token.name == ';', (token, return_type_and_name, parameters)
pass
# Looks like we got a method, not a function.
if len(return_type) > 2 and return_type[-1].name == '::':
return_type, in_class = \
self._GetReturnTypeAndClassName(return_type)
return Method(indices.start, indices.end, name.name, in_class,
return_type, parameters, modifiers, templated_types,
body, self.namespace_stack)
return Function(indices.start, indices.end, name.name, return_type,
parameters, modifiers, templated_types, body,
self.namespace_stack)
def _GetReturnTypeAndClassName(self, token_seq):
# Splitting the return type from the class name in a method
# can be tricky. For example, Return::Type::Is::Hard::To::Find().
# Where is the return type and where is the class name?
# The heuristic used is to pull the last name as the class name.
# This includes all the templated type info.
# TODO(nnorwitz): if there is only One name like in the
# example above, punt and assume the last bit is the class name.
# Ignore a :: prefix, if exists so we can find the first real name.
i = 0
if token_seq[0].name == '::':
i = 1
# Ignore a :: suffix, if exists.
end = len(token_seq) - 1
if token_seq[end-1].name == '::':
end -= 1
# Make a copy of the sequence so we can append a sentinel
# value. This is required for GetName will has to have some
# terminating condition beyond the last name.
seq_copy = token_seq[i:end]
seq_copy.append(tokenize.Token(tokenize.SYNTAX, '', 0, 0))
names = []
while i < end:
# Iterate through the sequence parsing out each name.
new_name, next = self.GetName(seq_copy[i:])
assert new_name, 'Got empty new_name, next=%s' % next
# We got a pointer or ref. Add it to the name.
if next and next.token_type == tokenize.SYNTAX:
new_name.append(next)
names.append(new_name)
i += len(new_name)
# Now that we have the names, it's time to undo what we did.
# Remove the sentinel value.
names[-1].pop()
# Flatten the token sequence for the return type.
return_type = [e for seq in names[:-1] for e in seq]
# The class name is the last name.
class_name = names[-1]
return return_type, class_name
def handle_bool(self):
pass
def handle_double(self): def handle_char(self):
pass pass
def handle_float(self): def handle_int(self):
pass pass
def handle_void(self): def handle_long(self):
pass pass
def handle_wchar_t(self): def handle_short(self):
pass pass
def handle_unsigned(self): def handle_double(self):
pass pass
def handle_signed(self): def handle_float(self):
pass pass
def _GetNestedType(self, ctor): def handle_void(self):
name = None pass
name_tokens, token = self.GetName()
if name_tokens:
name = ''.join([t.name for t in name_tokens])
# Handle forward declarations.
if token.token_type == tokenize.SYNTAX and token.name == ';':
return ctor(token.start, token.end, name, None,
self.namespace_stack)
if token.token_type == tokenize.NAME and self._handling_typedef:
self._AddBackToken(token)
return ctor(token.start, token.end, name, None,
self.namespace_stack)
# Must be the type declaration.
fields = list(self._GetMatchingChar('{', '}'))
del fields[-1] # Remove trailing '}'.
if token.token_type == tokenize.SYNTAX and token.name == '{':
next = self._GetNextToken()
new_type = ctor(token.start, token.end, name, fields,
self.namespace_stack)
# A name means this is an anonymous type and the name
# is the variable declaration.
if next.token_type != tokenize.NAME:
return new_type
name = new_type
token = next
# Must be variable declaration using the type prefixed with keyword.
assert token.token_type == tokenize.NAME, token
return self._CreateVariable(token, token.name, name, [], '', None)
def handle_struct(self):
# Special case the handling typedef/aliasing of structs here.
# It would be a pain to handle in the class code.
name_tokens, var_token = self.GetName()
if name_tokens:
next_token = self._GetNextToken()
is_syntax = (var_token.token_type == tokenize.SYNTAX and
var_token.name[0] in '*&')
is_variable = (var_token.token_type == tokenize.NAME and
next_token.name == ';')
variable = var_token
if is_syntax and not is_variable:
variable = next_token
temp = self._GetNextToken()
if temp.token_type == tokenize.SYNTAX and temp.name == '(':
# Handle methods declared to return a struct.
t0 = name_tokens[0]
struct = tokenize.Token(tokenize.NAME, 'struct',
t0.start-7, t0.start-2)
type_and_name = [struct]
type_and_name.extend(name_tokens)
type_and_name.extend((var_token, next_token))
return self._GetMethod(type_and_name, 0, None, False)
assert temp.name == ';', (temp, name_tokens, var_token)
if is_syntax or (is_variable and not self._handling_typedef):
modifiers = ['struct']
type_name = ''.join([t.name for t in name_tokens])
position = name_tokens[0]
return self._CreateVariable(position, variable.name, type_name,
modifiers, var_token.name, None)
name_tokens.extend((var_token, next_token))
self._AddBackTokens(name_tokens)
else:
self._AddBackToken(var_token)
return self._GetClass(Struct, VISIBILITY_PUBLIC, None)
def handle_union(self): def handle_wchar_t(self):
return self._GetNestedType(Union) pass
def handle_enum(self): def handle_unsigned(self):
return self._GetNestedType(Enum) pass
def handle_auto(self): def handle_signed(self):
# TODO(nnorwitz): warn about using auto? Probably not since it pass
# will be reclaimed and useful for C++0x.
pass
def handle_register(self): def _GetNestedType(self, ctor):
pass name = None
name_tokens, token = self.GetName()
if name_tokens:
name = ''.join([t.name for t in name_tokens])
# Handle forward declarations.
if token.token_type == tokenize.SYNTAX and token.name == ';':
return ctor(token.start, token.end, name, None,
self.namespace_stack)
if token.token_type == tokenize.NAME and self._handling_typedef:
self._AddBackToken(token)
return ctor(token.start, token.end, name, None,
self.namespace_stack)
# Must be the type declaration.
fields = list(self._GetMatchingChar('{', '}'))
del fields[-1] # Remove trailing '}'.
if token.token_type == tokenize.SYNTAX and token.name == '{':
next = self._GetNextToken()
new_type = ctor(token.start, token.end, name, fields,
self.namespace_stack)
# A name means this is an anonymous type and the name
# is the variable declaration.
if next.token_type != tokenize.NAME:
return new_type
name = new_type
token = next
# Must be variable declaration using the type prefixed with keyword.
assert token.token_type == tokenize.NAME, token
return self._CreateVariable(token, token.name, name, [], '', None)
def handle_struct(self):
# Special case the handling typedef/aliasing of structs here.
# It would be a pain to handle in the class code.
name_tokens, var_token = self.GetName()
if name_tokens:
next_token = self._GetNextToken()
is_syntax = (var_token.token_type == tokenize.SYNTAX and
var_token.name[0] in '*&')
is_variable = (var_token.token_type == tokenize.NAME and
next_token.name == ';')
variable = var_token
if is_syntax and not is_variable:
variable = next_token
temp = self._GetNextToken()
if temp.token_type == tokenize.SYNTAX and temp.name == '(':
# Handle methods declared to return a struct.
t0 = name_tokens[0]
struct = tokenize.Token(tokenize.NAME, 'struct',
t0.start-7, t0.start-2)
type_and_name = [struct]
type_and_name.extend(name_tokens)
type_and_name.extend((var_token, next_token))
return self._GetMethod(type_and_name, 0, None, False)
assert temp.name == ';', (temp, name_tokens, var_token)
if is_syntax or (is_variable and not self._handling_typedef):
modifiers = ['struct']
type_name = ''.join([t.name for t in name_tokens])
position = name_tokens[0]
return self._CreateVariable(position, variable.name, type_name,
modifiers, var_token.name, None)
name_tokens.extend((var_token, next_token))
self._AddBackTokens(name_tokens)
else:
self._AddBackToken(var_token)
return self._GetClass(Struct, VISIBILITY_PUBLIC, None)
def handle_union(self):
return self._GetNestedType(Union)
def handle_enum(self):
return self._GetNestedType(Enum)
def handle_auto(self):
# TODO(nnorwitz): warn about using auto? Probably not since it
# will be reclaimed and useful for C++0x.
pass
def handle_const(self): def handle_register(self):
pass pass
def handle_inline(self): def handle_const(self):
pass pass
def handle_extern(self): def handle_inline(self):
pass pass
def handle_static(self): def handle_extern(self):
pass pass
def handle_virtual(self): def handle_static(self):
# What follows must be a method. pass
token = token2 = self._GetNextToken()
if token.name == 'inline':
# HACK(nnorwitz): handle inline dtors by ignoring 'inline'.
token2 = self._GetNextToken()
if token2.token_type == tokenize.SYNTAX and token2.name == '~':
return self.GetMethod(FUNCTION_VIRTUAL + FUNCTION_DTOR, None)
assert token.token_type == tokenize.NAME or token.name == '::', token
return_type_and_name, _ = self._GetVarTokensUpToIgnoringTemplates(
tokenize.SYNTAX, '(') # )
return_type_and_name.insert(0, token)
if token2 is not token:
return_type_and_name.insert(1, token2)
return self._GetMethod(return_type_and_name, FUNCTION_VIRTUAL,
None, False)
def handle_volatile(self):
pass
def handle_mutable(self): def handle_virtual(self):
pass # What follows must be a method.
token = token2 = self._GetNextToken()
if token.name == 'inline':
# HACK(nnorwitz): handle inline dtors by ignoring 'inline'.
token2 = self._GetNextToken()
if token2.token_type == tokenize.SYNTAX and token2.name == '~':
return self.GetMethod(FUNCTION_VIRTUAL + FUNCTION_DTOR, None)
assert token.token_type == tokenize.NAME or token.name == '::', token
return_type_and_name, _ = self._GetVarTokensUpToIgnoringTemplates(
tokenize.SYNTAX, '(') # )
return_type_and_name.insert(0, token)
if token2 is not token:
return_type_and_name.insert(1, token2)
return self._GetMethod(return_type_and_name, FUNCTION_VIRTUAL,
None, False)
def handle_volatile(self):
pass
def handle_public(self): def handle_mutable(self):
assert self.in_class pass
self.visibility = VISIBILITY_PUBLIC
def handle_protected(self): def handle_public(self):
assert self.in_class assert self.in_class
self.visibility = VISIBILITY_PROTECTED self.visibility = VISIBILITY_PUBLIC
def handle_private(self): def handle_protected(self):
assert self.in_class assert self.in_class
self.visibility = VISIBILITY_PRIVATE self.visibility = VISIBILITY_PROTECTED
def handle_friend(self): def handle_private(self):
tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';') assert self.in_class
assert tokens self.visibility = VISIBILITY_PRIVATE
t0 = tokens[0]
return Friend(t0.start, t0.end, tokens, self.namespace_stack)
def handle_static_cast(self): def handle_friend(self):
pass tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';')
assert tokens
t0 = tokens[0]
return Friend(t0.start, t0.end, tokens, self.namespace_stack)
def handle_const_cast(self): def handle_static_cast(self):
pass pass
def handle_dynamic_cast(self): def handle_const_cast(self):
pass pass
def handle_reinterpret_cast(self): def handle_dynamic_cast(self):
pass pass
def handle_new(self): def handle_reinterpret_cast(self):
pass pass
def handle_delete(self): def handle_new(self):
tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';') pass
assert tokens
return Delete(tokens[0].start, tokens[0].end, tokens)
def handle_typedef(self): def handle_delete(self):
token = self._GetNextToken() tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';')
if (token.token_type == tokenize.NAME and assert tokens
keywords.IsKeyword(token.name)): return Delete(tokens[0].start, tokens[0].end, tokens)
# Token must be struct/enum/union/class.
method = getattr(self, 'handle_' + token.name) def handle_typedef(self):
self._handling_typedef = True token = self._GetNextToken()
tokens = [method()] if (token.token_type == tokenize.NAME and
self._handling_typedef = False keywords.IsKeyword(token.name)):
# Token must be struct/enum/union/class.
method = getattr(self, 'handle_' + token.name)
self._handling_typedef = True
tokens = [method()]
self._handling_typedef = False
else:
tokens = [token]
# Get the remainder of the typedef up to the semi-colon.
tokens.extend(self._GetTokensUpTo(tokenize.SYNTAX, ';'))
# TODO(nnorwitz): clean all this up.
assert tokens
name = tokens.pop()
indices = name
if tokens:
indices = tokens[0]
if not indices:
indices = token
if name.name == ')':
# HACK(nnorwitz): Handle pointers to functions "properly".
if (len(tokens) >= 4 and
tokens[1].name == '(' and tokens[2].name == '*'):
tokens.append(name)
name = tokens[3]
elif name.name == ']':
# HACK(nnorwitz): Handle arrays properly.
if len(tokens) >= 2:
tokens.append(name)
name = tokens[1]
new_type = tokens
if tokens and isinstance(tokens[0], tokenize.Token):
new_type = self.converter.ToType(tokens)[0]
return Typedef(indices.start, indices.end, name.name,
new_type, self.namespace_stack)
def handle_typeid(self):
pass # Not needed yet.
def handle_typename(self):
pass # Not needed yet.
def _GetTemplatedTypes(self):
result = {}
tokens = list(self._GetMatchingChar('<', '>'))
len_tokens = len(tokens) - 1 # Ignore trailing '>'.
i = 0
while i < len_tokens:
key = tokens[i].name
i += 1
if keywords.IsKeyword(key) or key == ',':
continue
type_name = default = None
if i < len_tokens:
i += 1
if tokens[i-1].name == '=':
assert i < len_tokens, '%s %s' % (i, tokens)
default, unused_next_token = self.GetName(tokens[i:])
i += len(default)
else: else:
tokens = [token] if tokens[i-1].name != ',':
# We got something like: Type variable.
# Get the remainder of the typedef up to the semi-colon. # Re-adjust the key (variable) and type_name (Type).
tokens.extend(self._GetTokensUpTo(tokenize.SYNTAX, ';')) key = tokens[i-1].name
type_name = tokens[i-2]
# TODO(nnorwitz): clean all this up.
assert tokens result[key] = (type_name, default)
name = tokens.pop() return result
indices = name
if tokens: def handle_template(self):
indices = tokens[0] token = self._GetNextToken()
if not indices: assert token.token_type == tokenize.SYNTAX, token
indices = token assert token.name == '<', token
if name.name == ')': templated_types = self._GetTemplatedTypes()
# HACK(nnorwitz): Handle pointers to functions "properly". # TODO(nnorwitz): for now, just ignore the template params.
if (len(tokens) >= 4 and token = self._GetNextToken()
tokens[1].name == '(' and tokens[2].name == '*'): if token.token_type == tokenize.NAME:
tokens.append(name) if token.name == 'class':
name = tokens[3] return self._GetClass(Class, VISIBILITY_PRIVATE, templated_types)
elif name.name == ']': elif token.name == 'struct':
# HACK(nnorwitz): Handle arrays properly. return self._GetClass(Struct, VISIBILITY_PUBLIC, templated_types)
if len(tokens) >= 2: elif token.name == 'friend':
tokens.append(name) return self.handle_friend()
name = tokens[1] self._AddBackToken(token)
new_type = tokens tokens, last = self._GetVarTokensUpTo(tokenize.SYNTAX, '(', ';')
if tokens and isinstance(tokens[0], tokenize.Token): tokens.append(last)
new_type = self.converter.ToType(tokens)[0] self._AddBackTokens(tokens)
return Typedef(indices.start, indices.end, name.name, if last.name == '(':
new_type, self.namespace_stack) return self.GetMethod(FUNCTION_NONE, templated_types)
# Must be a variable definition.
def handle_typeid(self): return None
pass # Not needed yet.
def handle_true(self):
def handle_typename(self): pass # Nothing to do.
pass # Not needed yet.
def handle_false(self):
def _GetTemplatedTypes(self): pass # Nothing to do.
result = {}
tokens = list(self._GetMatchingChar('<', '>')) def handle_asm(self):
len_tokens = len(tokens) - 1 # Ignore trailing '>'. pass # Not needed yet.
i = 0
while i < len_tokens: def handle_class(self):
key = tokens[i].name return self._GetClass(Class, VISIBILITY_PRIVATE, None)
i += 1
if keywords.IsKeyword(key) or key == ',': def _GetBases(self):
continue # Get base classes.
type_name = default = None bases = []
if i < len_tokens: while 1:
i += 1 token = self._GetNextToken()
if tokens[i-1].name == '=': assert token.token_type == tokenize.NAME, token
assert i < len_tokens, '%s %s' % (i, tokens) # TODO(nnorwitz): store kind of inheritance...maybe.
default, unused_next_token = self.GetName(tokens[i:]) if token.name not in ('public', 'protected', 'private'):
i += len(default) # If inheritance type is not specified, it is private.
else: # Just put the token back so we can form a name.
if tokens[i-1].name != ',': # TODO(nnorwitz): it would be good to warn about this.
# We got something like: Type variable.
# Re-adjust the key (variable) and type_name (Type).
key = tokens[i-1].name
type_name = tokens[i-2]
result[key] = (type_name, default)
return result
def handle_template(self):
token = self._GetNextToken()
assert token.token_type == tokenize.SYNTAX, token
assert token.name == '<', token
templated_types = self._GetTemplatedTypes()
# TODO(nnorwitz): for now, just ignore the template params.
token = self._GetNextToken()
if token.token_type == tokenize.NAME:
if token.name == 'class':
return self._GetClass(Class, VISIBILITY_PRIVATE, templated_types)
elif token.name == 'struct':
return self._GetClass(Struct, VISIBILITY_PUBLIC, templated_types)
elif token.name == 'friend':
return self.handle_friend()
self._AddBackToken(token) self._AddBackToken(token)
tokens, last = self._GetVarTokensUpTo(tokenize.SYNTAX, '(', ';') else:
tokens.append(last) # Check for virtual inheritance.
self._AddBackTokens(tokens) token = self._GetNextToken()
if last.name == '(': if token.name != 'virtual':
return self.GetMethod(FUNCTION_NONE, templated_types) self._AddBackToken(token)
# Must be a variable definition.
return None
def handle_true(self):
pass # Nothing to do.
def handle_false(self):
pass # Nothing to do.
def handle_asm(self):
pass # Not needed yet.
def handle_class(self):
return self._GetClass(Class, VISIBILITY_PRIVATE, None)
def _GetBases(self):
# Get base classes.
bases = []
while 1:
token = self._GetNextToken()
assert token.token_type == tokenize.NAME, token
# TODO(nnorwitz): store kind of inheritance...maybe.
if token.name not in ('public', 'protected', 'private'):
# If inheritance type is not specified, it is private.
# Just put the token back so we can form a name.
# TODO(nnorwitz): it would be good to warn about this.
self._AddBackToken(token)
else:
# Check for virtual inheritance.
token = self._GetNextToken()
if token.name != 'virtual':
self._AddBackToken(token)
else:
# TODO(nnorwitz): store that we got virtual for this base.
pass
base, next_token = self.GetName()
bases_ast = self.converter.ToType(base)
assert len(bases_ast) == 1, bases_ast
bases.append(bases_ast[0])
assert next_token.token_type == tokenize.SYNTAX, next_token
if next_token.name == '{':
token = next_token
break
# Support multiple inheritance.
assert next_token.name == ',', next_token
return bases, token
def _GetClass(self, class_type, visibility, templated_types):
class_name = None
class_token = self._GetNextToken()
if class_token.token_type != tokenize.NAME:
assert class_token.token_type == tokenize.SYNTAX, class_token
token = class_token
else: else:
# Skip any macro (e.g. storage class specifiers) after the # TODO(nnorwitz): store that we got virtual for this base.
# 'class' keyword. pass
next_token = self._GetNextToken() base, next_token = self.GetName()
if next_token.token_type == tokenize.NAME: bases_ast = self.converter.ToType(base)
self._AddBackToken(next_token) assert len(bases_ast) == 1, bases_ast
else: bases.append(bases_ast[0])
self._AddBackTokens([class_token, next_token]) assert next_token.token_type == tokenize.SYNTAX, next_token
name_tokens, token = self.GetName() if next_token.name == '{':
class_name = ''.join([t.name for t in name_tokens]) token = next_token
bases = None break
if token.token_type == tokenize.SYNTAX: # Support multiple inheritance.
if token.name == ';': assert next_token.name == ',', next_token
# Forward declaration. return bases, token
return class_type(class_token.start, class_token.end,
class_name, None, templated_types, None, def _GetClass(self, class_type, visibility, templated_types):
self.namespace_stack) class_name = None
if token.name in '*&': class_token = self._GetNextToken()
# Inline forward declaration. Could be method or data. if class_token.token_type != tokenize.NAME:
name_token = self._GetNextToken() assert class_token.token_type == tokenize.SYNTAX, class_token
next_token = self._GetNextToken() token = class_token
if next_token.name == ';': else:
# Handle data # Skip any macro (e.g. storage class specifiers) after the
modifiers = ['class'] # 'class' keyword.
return self._CreateVariable(class_token, name_token.name, next_token = self._GetNextToken()
class_name, if next_token.token_type == tokenize.NAME:
modifiers, token.name, None) self._AddBackToken(next_token)
else: else:
# Assume this is a method. self._AddBackTokens([class_token, next_token])
tokens = (class_token, token, name_token, next_token) name_tokens, token = self.GetName()
self._AddBackTokens(tokens) class_name = ''.join([t.name for t in name_tokens])
return self.GetMethod(FUNCTION_NONE, None) bases = None
if token.name == ':': if token.token_type == tokenize.SYNTAX:
bases, token = self._GetBases() if token.name == ';':
# Forward declaration.
body = None return class_type(class_token.start, class_token.end,
if token.token_type == tokenize.SYNTAX and token.name == '{': class_name, None, templated_types, None,
assert token.token_type == tokenize.SYNTAX, token self.namespace_stack)
assert token.name == '{', token if token.name in '*&':
# Inline forward declaration. Could be method or data.
ast = AstBuilder(self.GetScope(), self.filename, class_name, name_token = self._GetNextToken()
visibility, self.namespace_stack) next_token = self._GetNextToken()
body = list(ast.Generate()) if next_token.name == ';':
# Handle data
if not self._handling_typedef: modifiers = ['class']
token = self._GetNextToken() return self._CreateVariable(class_token, name_token.name,
if token.token_type != tokenize.NAME: class_name,
assert token.token_type == tokenize.SYNTAX, token modifiers, token.name, None)
assert token.name == ';', token
else:
new_class = class_type(class_token.start, class_token.end,
class_name, bases, None,
body, self.namespace_stack)
modifiers = []
return self._CreateVariable(class_token,
token.name, new_class,
modifiers, token.name, None)
else: else:
if not self._handling_typedef: # Assume this is a method.
self.HandleError('non-typedef token', token) tokens = (class_token, token, name_token, next_token)
self._AddBackToken(token) self._AddBackTokens(tokens)
return self.GetMethod(FUNCTION_NONE, None)
return class_type(class_token.start, class_token.end, class_name, if token.name == ':':
bases, templated_types, body, self.namespace_stack) bases, token = self._GetBases()
def handle_namespace(self): body = None
if token.token_type == tokenize.SYNTAX and token.name == '{':
assert token.token_type == tokenize.SYNTAX, token
assert token.name == '{', token
ast = AstBuilder(self.GetScope(), self.filename, class_name,
visibility, self.namespace_stack)
body = list(ast.Generate())
if not self._handling_typedef:
token = self._GetNextToken() token = self._GetNextToken()
# Support anonymous namespaces. if token.token_type != tokenize.NAME:
name = None assert token.token_type == tokenize.SYNTAX, token
if token.token_type == tokenize.NAME: assert token.name == ';', token
name = token.name
token = self._GetNextToken()
self.namespace_stack.append(name)
assert token.token_type == tokenize.SYNTAX, token
# Create an internal token that denotes when the namespace is complete.
internal_token = tokenize.Token(_INTERNAL_TOKEN, _NAMESPACE_POP,
None, None)
internal_token.whence = token.whence
if token.name == '=':
# TODO(nnorwitz): handle aliasing namespaces.
name, next_token = self.GetName()
assert next_token.name == ';', next_token
self._AddBackToken(internal_token)
else: else:
assert token.name == '{', token new_class = class_type(class_token.start, class_token.end,
tokens = list(self.GetScope()) class_name, bases, None,
# Replace the trailing } with the internal namespace pop token. body, self.namespace_stack)
tokens[-1] = internal_token
# Handle namespace with nothing in it. modifiers = []
self._AddBackTokens(tokens) return self._CreateVariable(class_token,
return None token.name, new_class,
modifiers, token.name, None)
def handle_using(self): else:
tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';') if not self._handling_typedef:
assert tokens self.HandleError('non-typedef token', token)
return Using(tokens[0].start, tokens[0].end, tokens) self._AddBackToken(token)
def handle_explicit(self): return class_type(class_token.start, class_token.end, class_name,
assert self.in_class bases, templated_types, body, self.namespace_stack)
# Nothing much to do.
# TODO(nnorwitz): maybe verify the method name == class name. def handle_namespace(self):
# This must be a ctor. token = self._GetNextToken()
return self.GetMethod(FUNCTION_CTOR, None) # Support anonymous namespaces.
name = None
def handle_this(self): if token.token_type == tokenize.NAME:
pass # Nothing to do. name = token.name
token = self._GetNextToken()
def handle_operator(self): self.namespace_stack.append(name)
# Pull off the next token(s?) and make that part of the method name. assert token.token_type == tokenize.SYNTAX, token
pass # Create an internal token that denotes when the namespace is complete.
internal_token = tokenize.Token(_INTERNAL_TOKEN, _NAMESPACE_POP,
None, None)
internal_token.whence = token.whence
if token.name == '=':
# TODO(nnorwitz): handle aliasing namespaces.
name, next_token = self.GetName()
assert next_token.name == ';', next_token
self._AddBackToken(internal_token)
else:
assert token.name == '{', token
tokens = list(self.GetScope())
# Replace the trailing } with the internal namespace pop token.
tokens[-1] = internal_token
# Handle namespace with nothing in it.
self._AddBackTokens(tokens)
return None
def handle_using(self):
tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';')
assert tokens
return Using(tokens[0].start, tokens[0].end, tokens)
def handle_explicit(self):
assert self.in_class
# Nothing much to do.
# TODO(nnorwitz): maybe verify the method name == class name.
# This must be a ctor.
return self.GetMethod(FUNCTION_CTOR, None)
def handle_this(self):
pass # Nothing to do.
def handle_operator(self):
# Pull off the next token(s?) and make that part of the method name.
pass
def handle_sizeof(self): def handle_sizeof(self):
pass pass
def handle_case(self): def handle_case(self):
pass pass
def handle_switch(self): def handle_switch(self):
pass pass
def handle_default(self): def handle_default(self):
token = self._GetNextToken() token = self._GetNextToken()
assert token.token_type == tokenize.SYNTAX assert token.token_type == tokenize.SYNTAX
assert token.name == ':' assert token.name == ':'
def handle_if(self): def handle_if(self):
pass pass
def handle_else(self): def handle_else(self):
pass pass
def handle_return(self): def handle_return(self):
tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';') tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';')
if not tokens: if not tokens:
return Return(self.current_token.start, self.current_token.end, None) return Return(self.current_token.start, self.current_token.end, None)
return Return(tokens[0].start, tokens[0].end, tokens) return Return(tokens[0].start, tokens[0].end, tokens)
def handle_goto(self): def handle_goto(self):
tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';') tokens = self._GetTokensUpTo(tokenize.SYNTAX, ';')
assert len(tokens) == 1, str(tokens) assert len(tokens) == 1, str(tokens)
return Goto(tokens[0].start, tokens[0].end, tokens[0].name) return Goto(tokens[0].start, tokens[0].end, tokens[0].name)
def handle_try(self): def handle_try(self):
pass # Not needed yet. pass # Not needed yet.
def handle_catch(self): def handle_catch(self):
pass # Not needed yet. pass # Not needed yet.
def handle_throw(self): def handle_throw(self):
pass # Not needed yet. pass # Not needed yet.
def handle_while(self): def handle_while(self):
pass pass
def handle_do(self): def handle_do(self):
pass pass
def handle_for(self): def handle_for(self):
pass pass
def handle_break(self): def handle_break(self):
self._IgnoreUpTo(tokenize.SYNTAX, ';') self._IgnoreUpTo(tokenize.SYNTAX, ';')
def handle_continue(self): def handle_continue(self):
self._IgnoreUpTo(tokenize.SYNTAX, ';') self._IgnoreUpTo(tokenize.SYNTAX, ';')
def BuilderFromSource(source, filename): def BuilderFromSource(source, filename):
"""Utility method that returns an AstBuilder from source code. """Utility method that returns an AstBuilder from source code.
Args: Args:
source: 'C++ source code' source: 'C++ source code'
...@@ -1698,64 +1710,64 @@ def BuilderFromSource(source, filename): ...@@ -1698,64 +1710,64 @@ def BuilderFromSource(source, filename):
Returns: Returns:
AstBuilder AstBuilder
""" """
return AstBuilder(tokenize.GetTokens(source), filename) return AstBuilder(tokenize.GetTokens(source), filename)
def PrintIndentifiers(filename, should_print): def PrintIndentifiers(filename, should_print):
"""Prints all identifiers for a C++ source file. """Prints all identifiers for a C++ source file.
Args: Args:
filename: 'file1' filename: 'file1'
should_print: predicate with signature: bool Function(token) should_print: predicate with signature: bool Function(token)
""" """
source = utils.ReadFile(filename, False) source = utils.ReadFile(filename, False)
if source is None: if source is None:
sys.stderr.write('Unable to find: %s\n' % filename) sys.stderr.write('Unable to find: %s\n' % filename)
return return
#print('Processing %s' % actual_filename) #print('Processing %s' % actual_filename)
builder = BuilderFromSource(source, filename) builder = BuilderFromSource(source, filename)
try: try:
for node in builder.Generate(): for node in builder.Generate():
if should_print(node): if should_print(node):
print(node.name) print(node.name)
except KeyboardInterrupt: except KeyboardInterrupt:
return return
except: except:
pass pass
def PrintAllIndentifiers(filenames, should_print): def PrintAllIndentifiers(filenames, should_print):
"""Prints all identifiers for each C++ source file in filenames. """Prints all identifiers for each C++ source file in filenames.
Args: Args:
filenames: ['file1', 'file2', ...] filenames: ['file1', 'file2', ...]
should_print: predicate with signature: bool Function(token) should_print: predicate with signature: bool Function(token)
""" """
for path in filenames: for path in filenames:
PrintIndentifiers(path, should_print) PrintIndentifiers(path, should_print)
def main(argv): def main(argv):
for filename in argv[1:]: for filename in argv[1:]:
source = utils.ReadFile(filename) source = utils.ReadFile(filename)
if source is None: if source is None:
continue continue
print('Processing %s' % filename) print('Processing %s' % filename)
builder = BuilderFromSource(source, filename) builder = BuilderFromSource(source, filename)
try: try:
entire_ast = filter(None, builder.Generate()) entire_ast = filter(None, builder.Generate())
except KeyboardInterrupt: except KeyboardInterrupt:
return return
except: except:
# Already printed a warning, print the traceback and continue. # Already printed a warning, print the traceback and continue.
traceback.print_exc() traceback.print_exc()
else: else:
if utils.DEBUG: if utils.DEBUG:
for ast in entire_ast: for ast in entire_ast:
print(ast) print(ast)
if __name__ == '__main__': if __name__ == '__main__':
main(sys.argv) main(sys.argv)
...@@ -35,11 +35,11 @@ from cpp import utils ...@@ -35,11 +35,11 @@ from cpp import utils
# Preserve compatibility with Python 2.3. # Preserve compatibility with Python 2.3.
try: try:
_dummy = set _dummy = set
except NameError: except NameError:
import sets import sets
set = sets.Set set = sets.Set
_VERSION = (1, 0, 1) # The version of this script. _VERSION = (1, 0, 1) # The version of this script.
# How many spaces to indent. Can set me with the INDENT environment variable. # How many spaces to indent. Can set me with the INDENT environment variable.
...@@ -47,202 +47,199 @@ _INDENT = 2 ...@@ -47,202 +47,199 @@ _INDENT = 2
def _RenderType(ast_type): def _RenderType(ast_type):
"""Renders the potentially recursively templated type into a string. """Renders the potentially recursively templated type into a string.
Args: Args:
ast_type: The AST of the type. ast_type: The AST of the type.
Returns: Returns:
Rendered string and a boolean to indicate whether we have multiple args Rendered string of the type.
(which is not handled correctly).
""" """
has_multiarg_error = False # Add modifiers like 'const'.
# Add modifiers like 'const'. modifiers = ''
modifiers = '' if ast_type.modifiers:
if ast_type.modifiers: modifiers = ' '.join(ast_type.modifiers) + ' '
modifiers = ' '.join(ast_type.modifiers) + ' ' return_type = modifiers + ast_type.name
return_type = modifiers + ast_type.name if ast_type.templated_types:
if ast_type.templated_types: # Collect template args.
# Collect template args. template_args = []
template_args = [] for arg in ast_type.templated_types:
for arg in ast_type.templated_types: rendered_arg = _RenderType(arg)
rendered_arg, e = _RenderType(arg) template_args.append(rendered_arg)
if e: has_multiarg_error = True return_type += '<' + ', '.join(template_args) + '>'
template_args.append(rendered_arg) if ast_type.pointer:
return_type += '<' + ', '.join(template_args) + '>' return_type += '*'
# We are actually not handling multi-template-args correctly. So mark it. if ast_type.reference:
if len(template_args) > 1: return_type += '&'
has_multiarg_error = True return return_type
if ast_type.pointer:
return_type += '*'
if ast_type.reference: def _GenerateArg(source):
return_type += '&' """Strips out comments, default arguments, and redundant spaces from a single argument.
return return_type, has_multiarg_error
Args:
source: A string for a single argument.
def _GetNumParameters(parameters, source):
num_parameters = len(parameters) Returns:
if num_parameters == 1: Rendered string of the argument.
first_param = parameters[0] """
if source[first_param.start:first_param.end].strip() == 'void': # Remove end of line comments before eliminating newlines.
# We must treat T(void) as a function with no parameters. arg = re.sub(r'//.*', '', source)
return 0
return num_parameters # Remove c-style comments.
arg = re.sub(r'/\*.*\*/', '', arg)
# Remove default arguments.
arg = re.sub(r'=.*', '', arg)
# Collapse spaces and newlines into a single space.
arg = re.sub(r'\s+', ' ', arg)
return arg.strip()
def _EscapeForMacro(s):
"""Escapes a string for use as an argument to a C++ macro."""
paren_count = 0
for c in s:
if c == '(':
paren_count += 1
elif c == ')':
paren_count -= 1
elif c == ',' and paren_count == 0:
return '(' + s + ')'
return s
def _GenerateMethods(output_lines, source, class_node): def _GenerateMethods(output_lines, source, class_node):
function_type = (ast.FUNCTION_VIRTUAL | ast.FUNCTION_PURE_VIRTUAL | function_type = (
ast.FUNCTION_OVERRIDE) ast.FUNCTION_VIRTUAL | ast.FUNCTION_PURE_VIRTUAL | ast.FUNCTION_OVERRIDE)
ctor_or_dtor = ast.FUNCTION_CTOR | ast.FUNCTION_DTOR ctor_or_dtor = ast.FUNCTION_CTOR | ast.FUNCTION_DTOR
indent = ' ' * _INDENT indent = ' ' * _INDENT
for node in class_node.body: for node in class_node.body:
# We only care about virtual functions. # We only care about virtual functions.
if (isinstance(node, ast.Function) and if (isinstance(node, ast.Function) and node.modifiers & function_type and
node.modifiers & function_type and not node.modifiers & ctor_or_dtor):
not node.modifiers & ctor_or_dtor): # Pick out all the elements we need from the original function.
# Pick out all the elements we need from the original function. modifiers = 'override'
const = '' if node.modifiers & ast.FUNCTION_CONST:
if node.modifiers & ast.FUNCTION_CONST: modifiers = 'const, ' + modifiers
const = 'CONST_'
num_parameters = _GetNumParameters(node.parameters, source) return_type = 'void'
return_type = 'void' if node.return_type:
if node.return_type: return_type = _EscapeForMacro(_RenderType(node.return_type))
return_type, has_multiarg_error = _RenderType(node.return_type)
if has_multiarg_error: args = []
for line in [ for p in node.parameters:
'// The following line won\'t really compile, as the return', arg = _GenerateArg(source[p.start:p.end])
'// type has multiple template arguments. To fix it, use a', args.append(_EscapeForMacro(arg))
'// typedef for the return type.']:
output_lines.append(indent + line) # Create the mock method definition.
tmpl = '' output_lines.extend([
if class_node.templated_types: '%sMOCK_METHOD(%s, %s, (%s), (%s));' %
tmpl = '_T' (indent, return_type, node.name, ', '.join(args), modifiers)
mock_method_macro = 'MOCK_%sMETHOD%d%s' % (const, num_parameters, tmpl) ])
args = ''
if node.parameters:
# Get the full text of the parameters from the start
# of the first parameter to the end of the last parameter.
start = node.parameters[0].start
end = node.parameters[-1].end
# Remove // comments.
args_strings = re.sub(r'//.*', '', source[start:end])
# Remove /* comments */.
args_strings = re.sub(r'/\*.*\*/', '', args_strings)
# Remove default arguments.
args_strings = re.sub(r'=.*,', ',', args_strings)
args_strings = re.sub(r'=.*', '', args_strings)
# Condense multiple spaces and eliminate newlines putting the
# parameters together on a single line. Ensure there is a
# space in an argument which is split by a newline without
# intervening whitespace, e.g.: int\nBar
args = re.sub(' +', ' ', args_strings.replace('\n', ' '))
# Create the mock method definition.
output_lines.extend(['%s%s(%s,' % (indent, mock_method_macro, node.name),
'%s%s(%s));' % (indent * 3, return_type, args)])
def _GenerateMocks(filename, source, ast_list, desired_class_names): def _GenerateMocks(filename, source, ast_list, desired_class_names):
processed_class_names = set() processed_class_names = set()
lines = [] lines = []
for node in ast_list: for node in ast_list:
if (isinstance(node, ast.Class) and node.body and if (isinstance(node, ast.Class) and node.body and
# desired_class_names being None means that all classes are selected. # desired_class_names being None means that all classes are selected.
(not desired_class_names or node.name in desired_class_names)): (not desired_class_names or node.name in desired_class_names)):
class_name = node.name class_name = node.name
parent_name = class_name parent_name = class_name
processed_class_names.add(class_name) processed_class_names.add(class_name)
class_node = node class_node = node
# Add namespace before the class. # Add namespace before the class.
if class_node.namespace: if class_node.namespace:
lines.extend(['namespace %s {' % n for n in class_node.namespace]) # } lines.extend(['namespace %s {' % n for n in class_node.namespace]) # }
lines.append('') lines.append('')
# Add template args for templated classes. # Add template args for templated classes.
if class_node.templated_types: if class_node.templated_types:
# TODO(paulchang): The AST doesn't preserve template argument order, # TODO(paulchang): The AST doesn't preserve template argument order,
# so we have to make up names here. # so we have to make up names here.
# TODO(paulchang): Handle non-type template arguments (e.g. # TODO(paulchang): Handle non-type template arguments (e.g.
# template<typename T, int N>). # template<typename T, int N>).
template_arg_count = len(class_node.templated_types.keys()) template_arg_count = len(class_node.templated_types.keys())
template_args = ['T%d' % n for n in range(template_arg_count)] template_args = ['T%d' % n for n in range(template_arg_count)]
template_decls = ['typename ' + arg for arg in template_args] template_decls = ['typename ' + arg for arg in template_args]
lines.append('template <' + ', '.join(template_decls) + '>') lines.append('template <' + ', '.join(template_decls) + '>')
parent_name += '<' + ', '.join(template_args) + '>' parent_name += '<' + ', '.join(template_args) + '>'
# Add the class prolog. # Add the class prolog.
lines.append('class Mock%s : public %s {' # } lines.append('class Mock%s : public %s {' # }
% (class_name, parent_name)) % (class_name, parent_name))
lines.append('%spublic:' % (' ' * (_INDENT // 2))) lines.append('%spublic:' % (' ' * (_INDENT // 2)))
# Add all the methods. # Add all the methods.
_GenerateMethods(lines, source, class_node) _GenerateMethods(lines, source, class_node)
# Close the class. # Close the class.
if lines: if lines:
# If there are no virtual methods, no need for a public label. # If there are no virtual methods, no need for a public label.
if len(lines) == 2: if len(lines) == 2:
del lines[-1] del lines[-1]
# Only close the class if there really is a class. # Only close the class if there really is a class.
lines.append('};') lines.append('};')
lines.append('') # Add an extra newline. lines.append('') # Add an extra newline.
# Close the namespace. # Close the namespace.
if class_node.namespace: if class_node.namespace:
for i in range(len(class_node.namespace) - 1, -1, -1): for i in range(len(class_node.namespace) - 1, -1, -1):
lines.append('} // namespace %s' % class_node.namespace[i]) lines.append('} // namespace %s' % class_node.namespace[i])
lines.append('') # Add an extra newline. lines.append('') # Add an extra newline.
if desired_class_names: if desired_class_names:
missing_class_name_list = list(desired_class_names - processed_class_names) missing_class_name_list = list(desired_class_names - processed_class_names)
if missing_class_name_list: if missing_class_name_list:
missing_class_name_list.sort() missing_class_name_list.sort()
sys.stderr.write('Class(es) not found in %s: %s\n' % sys.stderr.write('Class(es) not found in %s: %s\n' %
(filename, ', '.join(missing_class_name_list))) (filename, ', '.join(missing_class_name_list)))
elif not processed_class_names: elif not processed_class_names:
sys.stderr.write('No class found in %s\n' % filename) sys.stderr.write('No class found in %s\n' % filename)
return lines return lines
def main(argv=sys.argv): def main(argv=sys.argv):
if len(argv) < 2: if len(argv) < 2:
sys.stderr.write('Google Mock Class Generator v%s\n\n' % sys.stderr.write('Google Mock Class Generator v%s\n\n' %
'.'.join(map(str, _VERSION))) '.'.join(map(str, _VERSION)))
sys.stderr.write(__doc__) sys.stderr.write(__doc__)
return 1 return 1
global _INDENT global _INDENT
try: try:
_INDENT = int(os.environ['INDENT']) _INDENT = int(os.environ['INDENT'])
except KeyError: except KeyError:
pass pass
except: except:
sys.stderr.write('Unable to use indent of %s\n' % os.environ.get('INDENT')) sys.stderr.write('Unable to use indent of %s\n' % os.environ.get('INDENT'))
filename = argv[1] filename = argv[1]
desired_class_names = None # None means all classes in the source file. desired_class_names = None # None means all classes in the source file.
if len(argv) >= 3: if len(argv) >= 3:
desired_class_names = set(argv[2:]) desired_class_names = set(argv[2:])
source = utils.ReadFile(filename) source = utils.ReadFile(filename)
if source is None: if source is None:
return 1 return 1
builder = ast.BuilderFromSource(source, filename) builder = ast.BuilderFromSource(source, filename)
try: try:
entire_ast = filter(None, builder.Generate()) entire_ast = filter(None, builder.Generate())
except KeyboardInterrupt: except KeyboardInterrupt:
return return
except: except:
# An error message was already printed since we couldn't parse. # An error message was already printed since we couldn't parse.
sys.exit(1) sys.exit(1)
else: else:
lines = _GenerateMocks(filename, source, entire_ast, desired_class_names) lines = _GenerateMocks(filename, source, entire_ast, desired_class_names)
sys.stdout.write('\n'.join(lines)) sys.stdout.write('\n'.join(lines))
if __name__ == '__main__': if __name__ == '__main__':
main(sys.argv) main(sys.argv)
...@@ -29,43 +29,43 @@ from cpp import gmock_class ...@@ -29,43 +29,43 @@ from cpp import gmock_class
class TestCase(unittest.TestCase): class TestCase(unittest.TestCase):
"""Helper class that adds assert methods.""" """Helper class that adds assert methods."""
@staticmethod @staticmethod
def StripLeadingWhitespace(lines): def StripLeadingWhitespace(lines):
"""Strip leading whitespace in each line in 'lines'.""" """Strip leading whitespace in each line in 'lines'."""
return '\n'.join([s.lstrip() for s in lines.split('\n')]) return '\n'.join([s.lstrip() for s in lines.split('\n')])
def assertEqualIgnoreLeadingWhitespace(self, expected_lines, lines): def assertEqualIgnoreLeadingWhitespace(self, expected_lines, lines):
"""Specialized assert that ignores the indent level.""" """Specialized assert that ignores the indent level."""
self.assertEqual(expected_lines, self.StripLeadingWhitespace(lines)) self.assertEqual(expected_lines, self.StripLeadingWhitespace(lines))
class GenerateMethodsTest(TestCase): class GenerateMethodsTest(TestCase):
@staticmethod @staticmethod
def GenerateMethodSource(cpp_source): def GenerateMethodSource(cpp_source):
"""Convert C++ source to Google Mock output source lines.""" """Convert C++ source to Google Mock output source lines."""
method_source_lines = [] method_source_lines = []
# <test> is a pseudo-filename, it is not read or written. # <test> is a pseudo-filename, it is not read or written.
builder = ast.BuilderFromSource(cpp_source, '<test>') builder = ast.BuilderFromSource(cpp_source, '<test>')
ast_list = list(builder.Generate()) ast_list = list(builder.Generate())
gmock_class._GenerateMethods(method_source_lines, cpp_source, ast_list[0]) gmock_class._GenerateMethods(method_source_lines, cpp_source, ast_list[0])
return '\n'.join(method_source_lines) return '\n'.join(method_source_lines)
def testSimpleMethod(self): def testSimpleMethod(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual int Bar(); virtual int Bar();
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0(Bar,\nint());', 'MOCK_METHOD(int, Bar, (), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testSimpleConstructorsAndDestructor(self): def testSimpleConstructorsAndDestructor(self):
source = """ source = """
class Foo { class Foo {
public: public:
Foo(); Foo();
...@@ -76,26 +76,26 @@ class Foo { ...@@ -76,26 +76,26 @@ class Foo {
virtual int Bar() = 0; virtual int Bar() = 0;
}; };
""" """
# The constructors and destructor should be ignored. # The constructors and destructor should be ignored.
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0(Bar,\nint());', 'MOCK_METHOD(int, Bar, (), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testVirtualDestructor(self): def testVirtualDestructor(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual ~Foo(); virtual ~Foo();
virtual int Bar() = 0; virtual int Bar() = 0;
}; };
""" """
# The destructor should be ignored. # The destructor should be ignored.
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0(Bar,\nint());', 'MOCK_METHOD(int, Bar, (), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testExplicitlyDefaultedConstructorsAndDestructor(self): def testExplicitlyDefaultedConstructorsAndDestructor(self):
source = """ source = """
class Foo { class Foo {
public: public:
Foo() = default; Foo() = default;
...@@ -105,13 +105,13 @@ class Foo { ...@@ -105,13 +105,13 @@ class Foo {
virtual int Bar() = 0; virtual int Bar() = 0;
}; };
""" """
# The constructors and destructor should be ignored. # The constructors and destructor should be ignored.
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0(Bar,\nint());', 'MOCK_METHOD(int, Bar, (), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testExplicitlyDeletedConstructorsAndDestructor(self): def testExplicitlyDeletedConstructorsAndDestructor(self):
source = """ source = """
class Foo { class Foo {
public: public:
Foo() = delete; Foo() = delete;
...@@ -121,69 +121,69 @@ class Foo { ...@@ -121,69 +121,69 @@ class Foo {
virtual int Bar() = 0; virtual int Bar() = 0;
}; };
""" """
# The constructors and destructor should be ignored. # The constructors and destructor should be ignored.
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0(Bar,\nint());', 'MOCK_METHOD(int, Bar, (), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testSimpleOverrideMethod(self): def testSimpleOverrideMethod(self):
source = """ source = """
class Foo { class Foo {
public: public:
int Bar() override; int Bar() override;
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0(Bar,\nint());', 'MOCK_METHOD(int, Bar, (), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testSimpleConstMethod(self): def testSimpleConstMethod(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual void Bar(bool flag) const; virtual void Bar(bool flag) const;
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_CONST_METHOD1(Bar,\nvoid(bool flag));', 'MOCK_METHOD(void, Bar, (bool flag), (const, override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testExplicitVoid(self): def testExplicitVoid(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual int Bar(void); virtual int Bar(void);
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0(Bar,\nint(void));', 'MOCK_METHOD(int, Bar, (void), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testStrangeNewlineInParameter(self): def testStrangeNewlineInParameter(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual void Bar(int virtual void Bar(int
a) = 0; a) = 0;
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD1(Bar,\nvoid(int a));', 'MOCK_METHOD(void, Bar, (int a), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testDefaultParameters(self): def testDefaultParameters(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual void Bar(int a, char c = 'x') = 0; virtual void Bar(int a, char c = 'x') = 0;
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD2(Bar,\nvoid(int a, char c ));', 'MOCK_METHOD(void, Bar, (int a, char c), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testMultipleDefaultParameters(self): def testMultipleDefaultParameters(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual void Bar( virtual void Bar(
...@@ -195,47 +195,58 @@ class Foo { ...@@ -195,47 +195,58 @@ class Foo {
int const *& rp = aDefaultPointer) = 0; int const *& rp = aDefaultPointer) = 0;
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
"MOCK_METHOD7(Bar,\n" 'MOCK_METHOD(void, Bar, '
"void(int a , char c , const int* const p , const std::string& s , char tab[] , int const *& rp ));", '(int a, char c, const int* const p, const std::string& s, char tab[], int const *& rp), '
self.GenerateMethodSource(source)) '(override));', self.GenerateMethodSource(source))
def testConstDefaultParameter(self): def testMultipleSingleLineDefaultParameters(self):
source = """ source = """
class Foo {
public:
virtual void Bar(int a = 42, int b = 43, int c = 44) = 0;
};
"""
self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD(void, Bar, (int a, int b, int c), (override));',
self.GenerateMethodSource(source))
def testConstDefaultParameter(self):
source = """
class Test { class Test {
public: public:
virtual bool Bar(const int test_arg = 42) = 0; virtual bool Bar(const int test_arg = 42) = 0;
}; };
""" """
expected = 'MOCK_METHOD1(Bar,\nbool(const int test_arg ));' self.assertEqualIgnoreLeadingWhitespace(
self.assertEqualIgnoreLeadingWhitespace( 'MOCK_METHOD(bool, Bar, (const int test_arg), (override));',
expected, self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testConstRefDefaultParameter(self): def testConstRefDefaultParameter(self):
source = """ source = """
class Test { class Test {
public: public:
virtual bool Bar(const std::string& test_arg = "42" ) = 0; virtual bool Bar(const std::string& test_arg = "42" ) = 0;
}; };
""" """
expected = 'MOCK_METHOD1(Bar,\nbool(const std::string& test_arg ));' self.assertEqualIgnoreLeadingWhitespace(
self.assertEqualIgnoreLeadingWhitespace( 'MOCK_METHOD(bool, Bar, (const std::string& test_arg), (override));',
expected, self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testRemovesCommentsWhenDefaultsArePresent(self): def testRemovesCommentsWhenDefaultsArePresent(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual void Bar(int a = 42 /* a comment */, virtual void Bar(int a = 42 /* a comment */,
char /* other comment */ c= 'x') = 0; char /* other comment */ c= 'x') = 0;
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD2(Bar,\nvoid(int a , char c));', 'MOCK_METHOD(void, Bar, (int a, char c), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testDoubleSlashCommentsInParameterListAreRemoved(self): def testDoubleSlashCommentsInParameterListAreRemoved(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual void Bar(int a, // inline comments should be elided. virtual void Bar(int a, // inline comments should be elided.
...@@ -243,117 +254,111 @@ class Foo { ...@@ -243,117 +254,111 @@ class Foo {
) const = 0; ) const = 0;
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_CONST_METHOD2(Bar,\nvoid(int a, int b));', 'MOCK_METHOD(void, Bar, (int a, int b), (const, override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testCStyleCommentsInParameterListAreNotRemoved(self): def testCStyleCommentsInParameterListAreNotRemoved(self):
# NOTE(nnorwitz): I'm not sure if it's the best behavior to keep these # NOTE(nnorwitz): I'm not sure if it's the best behavior to keep these
# comments. Also note that C style comments after the last parameter # comments. Also note that C style comments after the last parameter
# are still elided. # are still elided.
source = """ source = """
class Foo { class Foo {
public: public:
virtual const string& Bar(int /* keeper */, int b); virtual const string& Bar(int /* keeper */, int b);
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD2(Bar,\nconst string&(int , int b));', 'MOCK_METHOD(const string&, Bar, (int, int b), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testArgsOfTemplateTypes(self): def testArgsOfTemplateTypes(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual int Bar(const vector<int>& v, map<int, string>* output); virtual int Bar(const vector<int>& v, map<int, string>* output);
};""" };"""
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD2(Bar,\n' 'MOCK_METHOD(int, Bar, (const vector<int>& v, (map<int, string>* output)), (override));',
'int(const vector<int>& v, map<int, string>* output));', self.GenerateMethodSource(source))
self.GenerateMethodSource(source))
def testReturnTypeWithOneTemplateArg(self): def testReturnTypeWithOneTemplateArg(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual vector<int>* Bar(int n); virtual vector<int>* Bar(int n);
};""" };"""
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD1(Bar,\nvector<int>*(int n));', 'MOCK_METHOD(vector<int>*, Bar, (int n), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testReturnTypeWithManyTemplateArgs(self): def testReturnTypeWithManyTemplateArgs(self):
source = """ source = """
class Foo { class Foo {
public: public:
virtual map<int, string> Bar(); virtual map<int, string> Bar();
};""" };"""
# Comparing the comment text is brittle - we'll think of something self.assertEqualIgnoreLeadingWhitespace(
# better in case this gets annoying, but for now let's keep it simple. 'MOCK_METHOD((map<int, string>), Bar, (), (override));',
self.assertEqualIgnoreLeadingWhitespace( self.GenerateMethodSource(source))
'// The following line won\'t really compile, as the return\n'
'// type has multiple template arguments. To fix it, use a\n' def testSimpleMethodInTemplatedClass(self):
'// typedef for the return type.\n' source = """
'MOCK_METHOD0(Bar,\nmap<int, string>());',
self.GenerateMethodSource(source))
def testSimpleMethodInTemplatedClass(self):
source = """
template<class T> template<class T>
class Foo { class Foo {
public: public:
virtual int Bar(); virtual int Bar();
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD0_T(Bar,\nint());', 'MOCK_METHOD(int, Bar, (), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testPointerArgWithoutNames(self): def testPointerArgWithoutNames(self):
source = """ source = """
class Foo { class Foo {
virtual int Bar(C*); virtual int Bar(C*);
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD1(Bar,\nint(C*));', 'MOCK_METHOD(int, Bar, (C*), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testReferenceArgWithoutNames(self): def testReferenceArgWithoutNames(self):
source = """ source = """
class Foo { class Foo {
virtual int Bar(C&); virtual int Bar(C&);
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD1(Bar,\nint(C&));', 'MOCK_METHOD(int, Bar, (C&), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
def testArrayArgWithoutNames(self): def testArrayArgWithoutNames(self):
source = """ source = """
class Foo { class Foo {
virtual int Bar(C[]); virtual int Bar(C[]);
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(
'MOCK_METHOD1(Bar,\nint(C[]));', 'MOCK_METHOD(int, Bar, (C[]), (override));',
self.GenerateMethodSource(source)) self.GenerateMethodSource(source))
class GenerateMocksTest(TestCase): class GenerateMocksTest(TestCase):
@staticmethod @staticmethod
def GenerateMocks(cpp_source): def GenerateMocks(cpp_source):
"""Convert C++ source to complete Google Mock output source.""" """Convert C++ source to complete Google Mock output source."""
# <test> is a pseudo-filename, it is not read or written. # <test> is a pseudo-filename, it is not read or written.
filename = '<test>' filename = '<test>'
builder = ast.BuilderFromSource(cpp_source, filename) builder = ast.BuilderFromSource(cpp_source, filename)
ast_list = list(builder.Generate()) ast_list = list(builder.Generate())
lines = gmock_class._GenerateMocks(filename, cpp_source, ast_list, None) lines = gmock_class._GenerateMocks(filename, cpp_source, ast_list, None)
return '\n'.join(lines) return '\n'.join(lines)
def testNamespaces(self): def testNamespaces(self):
source = """ source = """
namespace Foo { namespace Foo {
namespace Bar { class Forward; } namespace Bar { class Forward; }
namespace Baz { namespace Baz {
...@@ -366,96 +371,91 @@ class Test { ...@@ -366,96 +371,91 @@ class Test {
} // namespace Baz } // namespace Baz
} // namespace Foo } // namespace Foo
""" """
expected = """\ expected = """\
namespace Foo { namespace Foo {
namespace Baz { namespace Baz {
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD0(Foo, MOCK_METHOD(void, Foo, (), (override));
void());
}; };
} // namespace Baz } // namespace Baz
} // namespace Foo } // namespace Foo
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testClassWithStorageSpecifierMacro(self): def testClassWithStorageSpecifierMacro(self):
source = """ source = """
class STORAGE_SPECIFIER Test { class STORAGE_SPECIFIER Test {
public: public:
virtual void Foo(); virtual void Foo();
}; };
""" """
expected = """\ expected = """\
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD0(Foo, MOCK_METHOD(void, Foo, (), (override));
void());
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testTemplatedForwardDeclaration(self): def testTemplatedForwardDeclaration(self):
source = """ source = """
template <class T> class Forward; // Forward declaration should be ignored. template <class T> class Forward; // Forward declaration should be ignored.
class Test { class Test {
public: public:
virtual void Foo(); virtual void Foo();
}; };
""" """
expected = """\ expected = """\
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD0(Foo, MOCK_METHOD(void, Foo, (), (override));
void());
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testTemplatedClass(self): def testTemplatedClass(self):
source = """ source = """
template <typename S, typename T> template <typename S, typename T>
class Test { class Test {
public: public:
virtual void Foo(); virtual void Foo();
}; };
""" """
expected = """\ expected = """\
template <typename T0, typename T1> template <typename T0, typename T1>
class MockTest : public Test<T0, T1> { class MockTest : public Test<T0, T1> {
public: public:
MOCK_METHOD0_T(Foo, MOCK_METHOD(void, Foo, (), (override));
void());
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testTemplateInATemplateTypedef(self): def testTemplateInATemplateTypedef(self):
source = """ source = """
class Test { class Test {
public: public:
typedef std::vector<std::list<int>> FooType; typedef std::vector<std::list<int>> FooType;
virtual void Bar(const FooType& test_arg); virtual void Bar(const FooType& test_arg);
}; };
""" """
expected = """\ expected = """\
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD1(Bar, MOCK_METHOD(void, Bar, (const FooType& test_arg), (override));
void(const FooType& test_arg));
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testTemplateInATemplateTypedefWithComma(self): def testTemplateInATemplateTypedefWithComma(self):
source = """ source = """
class Test { class Test {
public: public:
typedef std::function<void( typedef std::function<void(
...@@ -463,18 +463,33 @@ class Test { ...@@ -463,18 +463,33 @@ class Test {
virtual void Bar(const FooType& test_arg); virtual void Bar(const FooType& test_arg);
}; };
""" """
expected = """\ expected = """\
class MockTest : public Test {
public:
MOCK_METHOD(void, Bar, (const FooType& test_arg), (override));
};
"""
self.assertEqualIgnoreLeadingWhitespace(expected,
self.GenerateMocks(source))
def testParenthesizedCommaInArg(self):
source = """
class Test {
public:
virtual void Bar(std::function<void(int, int)> f);
};
"""
expected = """\
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD1(Bar, MOCK_METHOD(void, Bar, (std::function<void(int, int)> f), (override));
void(const FooType& test_arg));
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testEnumType(self): def testEnumType(self):
source = """ source = """
class Test { class Test {
public: public:
enum Bar { enum Bar {
...@@ -483,18 +498,17 @@ class Test { ...@@ -483,18 +498,17 @@ class Test {
virtual void Foo(); virtual void Foo();
}; };
""" """
expected = """\ expected = """\
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD0(Foo, MOCK_METHOD(void, Foo, (), (override));
void());
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testEnumClassType(self): def testEnumClassType(self):
source = """ source = """
class Test { class Test {
public: public:
enum class Bar { enum class Bar {
...@@ -503,18 +517,17 @@ class Test { ...@@ -503,18 +517,17 @@ class Test {
virtual void Foo(); virtual void Foo();
}; };
""" """
expected = """\ expected = """\
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD0(Foo, MOCK_METHOD(void, Foo, (), (override));
void());
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
def testStdFunction(self): def testStdFunction(self):
source = """ source = """
class Test { class Test {
public: public:
Test(std::function<int(std::string)> foo) : foo_(foo) {} Test(std::function<int(std::string)> foo) : foo_(foo) {}
...@@ -525,16 +538,15 @@ class Test { ...@@ -525,16 +538,15 @@ class Test {
std::function<int(std::string)> foo_; std::function<int(std::string)> foo_;
}; };
""" """
expected = """\ expected = """\
class MockTest : public Test { class MockTest : public Test {
public: public:
MOCK_METHOD0(foo, MOCK_METHOD(std::function<int (std::string)>, foo, (), (override));
std::function<int (std::string)>());
}; };
""" """
self.assertEqualIgnoreLeadingWhitespace( self.assertEqualIgnoreLeadingWhitespace(expected,
expected, self.GenerateMocks(source)) self.GenerateMocks(source))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment